You can get 25 to 50 points
Goal of the project is to create more thorough analysis of the chosen dataset than in the previous two smaller projects.
The dataset selection is up to you however is has come from image, text or time series domain.
Every project must include brief description of the dataset
What metrics scores you have decided to use
Try at least 3 different models
Mandatory part of every project is a summary at the end in which you summarize the most interesting insight obtained.
Result is a Jupyter Notebook with descriptions included or a PDF report + source codes.
Deadline is 20. 4. 2022
%%HTML
<script src="require.js"></script>
import plotly.io as pio
pio.renderers.default = "notebook"
import os
import pandas as pd
from enum import Enum
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from gensim.parsing.preprocessing import (
preprocess_string,
strip_tags as strip_tags_gensim,
strip_punctuation as strip_punctuation_gensim,
strip_multiple_whitespaces as strip_multiple_whitespaces_gensim,
strip_numeric as strip_numeric_gensim,
remove_stopwords as remove_stopwords_gensim,
strip_short as strip_short_gensim,
stem_text as stem_text_gensim,
)
import nltk
from nltk.stem import WordNetLemmatizer
import plotly.express as px
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from wordcloud import WordCloud
import itertools
from sklearn import preprocessing
import pandas as pd
import time
import seaborn as sns
nltk.download('wordnet')
[nltk_data] Downloading package wordnet to [nltk_data] /home/usp/pro0255/nltk_data... [nltk_data] Package wordnet is already up-to-date!
True
EXPERIMENTS_SAVE_DIRECTORY = ['.', 'experiments']
DIRECTORY = ['..', 'data', 'gutenberg']
LOAD_5A3S = ["5Authors", "Sentence3" ]
LOAD_15A3S = ["15Authors", "Sentence3"]
AUTHORS_FILENAME = ['authors.csv']
DATA_FILENAME = ['data.csv']
TEXT_COLUMN = "text"
LABEL_COLUMN = "label"
BLANK_DESCRIPTION = "Nada"
PROJECT_CSV_DELIMITER = ";"
TRAIN_SIZE = 0.75
VALIDATION_SIZE = 0.10
TEST_SIZE = 0.15
RANDOM_STATE = 7
NORMALIZE_LABEL = 15000
LOG = 'log.csv'
SUMMAR = 'sum.csv'
class PreprocessingType(Enum):
Default = "Default"
Lowercase = "Lowercase"
CaseInterpunction = "CaseInterpunction"
Raw = "Raw"
Blank = BLANK_DESCRIPTION
def create_dataset_from_dataframe(dataframe):
features, target = dataframe[TEXT_COLUMN], dataframe[LABEL_COLUMN]
return create_dataset_from_Xy(features, target)
def create_dataset_from_Xy(X, y):
return tf.data.Dataset.from_tensor_slices((X, y))
def create_encoder_from_path(path):
authors = pd.read_csv(path, sep=";")
ids = authors["AuthorId"].values
encoder = preprocessing.LabelEncoder()
encoder.fit(ids)
return encoder
class CSVLogger(tf.keras.callbacks.Callback):
def __init__(self, path):
self.path = os.path.sep.join([path, LOG])
self.timetaken = time.time()
self.state = {}
def on_epoch_end(self, epoch, logs={}):
logs["time"] = time.time() - self.timetaken
self.state[epoch] = logs
def on_train_end(self, logs={}):
headers = []
for k, v in self.state.items():
headers = self.state[k].keys()
break
data = {k: self.state[k].values() for k, v in self.state.items()}
df = pd.DataFrame.from_dict(data, orient="index")
df.columns = headers
df.to_csv(self.path, sep=';')
class Visualizer:
def __init__(self):
pass
def create_max_min_mean_len(self, tuples):
#d = vis.create_max_min_mean_len([("test", test), ("feste", test)]) test
res = pd.DataFrame()
for name, data in tuples:
x = data.copy()
x['len'] = x[TEXT_COLUMN].apply(len)
x = x.groupby(LABEL_COLUMN).len.agg(['mean', 'min', 'max'])
together = pd.DataFrame.from_dict(
{"together":
{
'mean': np.mean(x['mean']),
'min': np.min(x['min']),
'max': np.max(x['max'])
}
},
orient="index"
)
x = pd.concat([together, x])
x = x.reset_index()
x = pd.melt(x, id_vars=['index'], var_name='value_type', value_name='value')
x['df_type'] = name
res = pd.concat([res, x])
return res
def show_mean(self, dataframe):
return self.show_type(dataframe, 'mean')
def show_max(self, dataframe):
return self.show_type(dataframe, 'max')
def show_min(self, dataframe):
return self.show_type(dataframe, 'min')
def show_type(self, dataframe, spe_type):
return px.bar(dataframe[dataframe.value_type == spe_type], x="df_type", y="value", color='index', barmode='group')
def seq_dist(self, dataframe):
x = dataframe.copy()
x['len'] = x[TEXT_COLUMN].apply(len)
return px.histogram(x, x='len', color='label', title='')
def create_all_words(self, dataframe):
x = dataframe.copy()
all_words = list(itertools.chain.from_iterable([sentence.split(' ') for sentence in x[TEXT_COLUMN]]))
dist = nltk.FreqDist(all_words)
return dist
def generate_top_words(self, dataframe):
x = dataframe.copy()
res = {}
for current_label in np.unique(test.label.values):
subframe = x[x.label == current_label]
res[current_label] = self.show_top_words(subframe)
res['all'] = self.show_top_words(x)
return res
def show_top_words(self, dataframe, n=10):
dist = self.create_all_words(dataframe)
df = pd.DataFrame.from_dict(dict(dist), orient="index").reset_index()
df.columns = ['word', 'freq']
df = df.sort_values(by='freq', ascending=False)
first_n = df.iloc[0:n, :]
return px.bar(first_n, x="word", y="freq", color='word', title=f"{n} most freq words")
def show_wordcloud(self, wordcloud):
plt.figure(figsize=[10, 10])
plt.axis("off")
x = plt.imshow(wordcloud, interpolation="bilinear")
return x
def generate_wordclouds(self, dataframe):
#result = vis.generate_wordclouds(test) test
x = dataframe.copy()
res = {}
for current_label in np.unique(test.label.values):
subframe = x[x.label == current_label]
res[current_label] = self.wordcloud(subframe)
res['all'] = self.wordcloud(x)
return res
def wordcloud(self, dataframe, max_words=100):
x = dataframe.copy()
current_text = " ".join(x.text.values)
wordcloud = WordCloud(max_font_size=50, max_words=100, background_color="white").generate(current_text)
return wordcloud
def prediction_to_labels(y_pred):
y_pred = np.argmax(y_pred, axis=1)
return y_pred
BLACKLIST = [
'CHAPTER'
]
class TextPreprocessor:
def __init__(self) -> None:
self.strip_short_default = self.create_strip_short_method(3)
self.lemma_text = self.create_lemma_text()
def strip_tags(self, text):
return strip_tags_gensim(text)
def strip_upper_words(self, text):
return [word for word in text.split(' ') if word.upper() != word]
def remove_when_blacklisted(self, text):
current_text = set(text.split(' '))
blacklist = set(BLACKLIST)
l = len(current_text.intersection(blacklist))
if l > 0:
return ''
return text
def strip_punctuation(self, text):
return strip_punctuation_gensim(text)
def strip_multiple_whitespaces(self, text):
return strip_multiple_whitespaces_gensim(text)
def strip_numeric(self, text):
return strip_numeric_gensim(text)
def strip_stopwords(self, text):
return remove_stopwords_gensim(text)
def strip_short(self, text, minsize=3):
return strip_short_gensim(text, minsize)
def strip_short(self, text):
return strip_short_gensim(text)
def create_strip_short_method(self, minsize=3):
# TODO: fix
print(f"Creating shorting method with min = {minsize}")
def strip_short(text, minsize=minsize):
return strip_short_gensim(text, minsize)
return strip_short
def stem_text(self, text):
return stem_text_gensim(text)
def to_lowercase(self, text):
return text.lower()
def create_lemma_text(self):
instance = WordNetLemmatizer()
print(f"Creating lemma method with instance {instance}")
def lemma_text(text):
word_list = nltk.word_tokenize(text)
return " ".join([instance.lemmatize(word) for word in word_list])
return lemma_text
def create_preprocess_string_func(self, filters, tokenized=False):
def preprocess_func(text):
result = preprocess_string(text, filters)
return result if tokenized else " ".join(result)
return preprocess_func
def default_preprocessing(self):
return self.create_preprocess_string_func(
[
self.remove_when_blacklisted,
self.to_lowercase,
self.strip_punctuation,
self.strip_tags,
self.strip_multiple_whitespaces,
self.strip_numeric,
self.strip_stopwords,
self.strip_short,
self.lemma_text,
]
)
def default_lowerinterpunction(self):
return self.create_preprocess_string_func(
[
self.remove_when_blacklisted,
self.to_lowercase,
self.strip_punctuation,
self.strip_multiple_whitespaces,
self.strip_numeric
]
)
class PreprocessingFactory:
def __init__(self) -> None:
self.preprocessor = TextPreprocessor()
self.build_dic()
def build_dic(self):
self.dic = {
PreprocessingType.Default: self.preprocessor.default_preprocessing(),
PreprocessingType.Lowercase: self.preprocessor.create_preprocess_string_func(
[self.preprocessor.to_lowercase]
),
PreprocessingType.Raw: lambda x: x,
PreprocessingType.CaseInterpunction: self.preprocessor.default_lowerinterpunction(),
PreprocessingType.Blank: None,
}
def create(self, preprocessing_type):
return self.dic[preprocessing_type]
def get_load_path(directory, combination, filename):
return os.path.sep.join(directory + combination + filename)
def get_load_path_53(filename=DATA_FILENAME):
return get_load_path(DIRECTORY, LOAD_5A3S, filename)
def get_load_path_153(filename=DATA_FILENAME):
return get_load_path(DIRECTORY, LOAD_15A3S, filename)
def load_dataset_from_path_with_normalization(path, normalize=None, preprocessing_type=None):
factory = PreprocessingFactory()
normalize_final = None
if normalize is not None:
print('Specified normalize method')
normalize_final = normalize
else:
print(f'Specified type {preprocessing_type.value}')
if preprocessing_type is None:
normalize_final = factory.create(PreprocessingType.Default)
else:
normalize_final = factory.create(preprocessing_type)
dataset = load_dataset_from_path(path)
dataset[TEXT_COLUMN] = dataset[TEXT_COLUMN].apply(normalize_final)
return dataset
def load_dataset_from_path(path):
dataset = pd.read_csv(path, sep=PROJECT_CSV_DELIMITER, header=None)
dataset.columns = [TEXT_COLUMN, LABEL_COLUMN]
return dataset
def normalize_dataframe_to_size(dataframe, size):
all_labels = dataframe[LABEL_COLUMN].unique()
new_dataframe = pd.DataFrame()
for label in all_labels:
selected_dataframe = dataframe[dataframe.label == label].sample(size, random_state=RANDOM_STATE)
new_dataframe = pd.concat([new_dataframe, selected_dataframe])
return shuffle(new_dataframe, random_state=RANDOM_STATE)
def split_dataframe_to_train_test_valid(dataframe, test_size=TEST_SIZE, valid_size=VALIDATION_SIZE):
features, target = dataframe[TEXT_COLUMN], dataframe[LABEL_COLUMN]
X_train, X_test, y_train, y_test = train_test_split(features, target, test_size=test_size, random_state=RANDOM_STATE)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=valid_size, random_state=RANDOM_STATE)
print(f"Train {X_train.shape}")
print(f"Valid {X_valid.shape}")
print(f"Test {X_test.shape}")
return X_train, X_valid, X_test, y_train, y_valid, y_test
test = load_dataset_from_path_with_normalization(get_load_path_53(), None, PreprocessingType.Raw)
Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Raw
test = normalize_dataframe_to_size(test, NORMALIZE_LABEL)
X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(test)
Train (57375,) Valid (6375,) Test (11250,)
Datová sada je vyrobna vlastní sílou. Jde o textová data, která se parují s autorem těchto dat. Z hlediska textové stylu tento text můžeme zařadit do uměleckého stylu.
Motivací proč datová sada byla vytvořena je pokusit se identifikovat autora textu. Tento problém je určení autorství textu. Pakliže se na to budeme dívat skrz oči datového analytika, pak je to převlečený klasifikační problém, kde se snažíme předikonvat n tříd. Zde jako třídu myslíme autora textu.
Projekt Gutenberg je dobrovolnická snaha digitalizovat, archivovat a distribuovat kulturní díla. Byl založen v roce 1971 a je nejstarší digitální knihovnou. Většina děl jsou plné texty knih se statusem volného díla.
Jak již bylo zmíněno, datová sada byla vytvořena částečně vlastní sílou a částečně s výpomocí některých veřejně dostupných knihoven nalezených na internetu.
V mém případě jsem využil knihovnu do R (https://cran.r-project.org/web/packages/gutenbergr/vignettes/intro.html), která poskytuje jednoduché API pro stáhnutí všech děl z Projektu Gutenberg.
Před samotným stažením byla byla všechna dostupná díla vyfiltrovaná dle anglického jazyka, a tak aby měla autora. Všechna tato díla byla stažena ve formátu json, tak aby z nich pak šlo jednoduše vytvořit potřebný dataset. Příklad takové json souboru lze najít v adresáři example.
Poté co se mi podařilo úspěšně stáhnut všechna díla byly vytvořeny datové sady. Vždy byl specifikován počet autorů a velikost textové části. Pro představu pokud dané dílo bylo od autora námi požadovaného, pak celé dílo bylo rozsekáno na textové části o n větách. Kde věta byla ukončena dle typických znaků jako jsou {., !, ?}. Takto byly vytvořeny záznamy, které reprezentují strukturu datové sady.
Struktura datové sady:
V tomto projektu do předmětu byly zpracovány datové sady o 5 a 10 autorech s velikosti textové sady 3.
Zde bude představena datová sada blíže.
data_test = load_dataset_from_path_with_normalization(get_load_path_53(), lambda x: x)
authors_test = pd.read_csv(get_load_path_53(AUTHORS_FILENAME), sep=';')
Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified normalize method
data_test.head()
| text | label | |
|---|---|---|
| 0 | THE TRAGEDY OF PUDD'NHEAD WILSON by Mark Twai... | 53 |
| 1 | These chapters are right, now, in every detail... | 53 |
| 2 | Given under my hand this second day of January... | 53 |
| 3 | In 1830 it was a snug collection of modest one... | 53 |
| 4 | Then that house was complete, and its contentm... | 53 |
Každá datová sada byla vytvořena s doprovodným csv souborem obsajujícím další data k autorovi. Jak zde můžeme vidět, víme jak se autor jmenoval.
authors_test.head()
| AuthorId | Author | |
|---|---|---|
| 0 | 761 | Lytton, Edward Bulwer Lytton, Baron |
| 1 | 1800 | Ebers, Georg |
| 2 | 53 | Twain, Mark |
| 3 | 8659 | Kingston, William Henry Giles |
| 4 | 1285 | Parker, Gilbert |
label_counts = data_test.groupby(by=["label"]).size().reset_index(name="counts")
label_counts.head()
| label | counts | |
|---|---|---|
| 0 | 53 | 74224 |
| 1 | 761 | 144504 |
| 2 | 1285 | 133159 |
| 3 | 1800 | 107653 |
| 4 | 8659 | 140495 |
Lze pozorovat nevyváženost datové sady, a proto bude vybrán z každé třídy pouze určitý počet záznamů.
px.bar(data_frame=label_counts, y="counts", barmode="group")
norm_data_test = normalize_dataframe_to_size(data_test, NORMALIZE_LABEL)
label_counts = norm_data_test.groupby(by=["label"]).size().reset_index(name="counts")
px.bar(data_frame=label_counts, y="counts")
Je dobrý přístup textová data před samotným využitím předzpracovat. Toto předzpracování pro příklad zajištuje:
Díky toho získáme některé žadoucí výsledky:
vis = Visualizer()
raw_data = normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Raw),
NORMALIZE_LABEL
)
lowered_data = normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Lowercase
),
NORMALIZE_LABEL
)
default_data = normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Default),
NORMALIZE_LABEL
)
lowerinterpunction_data = normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.CaseInterpunction),
NORMALIZE_LABEL
)
Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Raw Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Lowercase Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Default Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type CaseInterpunction
print(len(raw_data))
raw_data.head()
75000
| text | label | |
|---|---|---|
| 182816 | Thrice a deep breath the knight relieved did d... | 761 |
| 318651 | But they were not cool deeps by any means, for... | 53 |
| 285365 | retorted Klea. "One of her escort threw them t... | 1800 |
| 322281 | The prince is an educated gentleman. His cultu... | 53 |
| 163774 | I thought: Now, this is the man whom I saw twe... | 53 |
print(len(lowered_data))
lowered_data.head()
75000
| text | label | |
|---|---|---|
| 182816 | thrice a deep breath the knight relieved did d... | 761 |
| 318651 | but they were not cool deeps by any means, for... | 53 |
| 285365 | retorted klea. "one of her escort threw them t... | 1800 |
| 322281 | the prince is an educated gentleman. his cultu... | 53 |
| 163774 | i thought: now, this is the man whom i saw twe... | 53 |
print(len(lowerinterpunction_data))
lowerinterpunction_data.head()
75000
| text | label | |
|---|---|---|
| 182816 | thrice a deep breath the knight relieved did d... | 761 |
| 318651 | but they were not cool deeps by any means for ... | 53 |
| 285365 | retorted klea one of her escort threw them to ... | 1800 |
| 322281 | the prince is an educated gentleman his cultur... | 53 |
| 163774 | i thought now this is the man whom i saw twent... | 53 |
print(len(default_data))
default_data.head()
75000
| text | label | |
|---|---|---|
| 182816 | thrice deep breath knight relieved draw fair v... | 761 |
| 318651 | cool deep mean sun ray weltering hot little br... | 53 |
| 285365 | retorted klea escort threw drop subject pray w... | 1800 |
| 322281 | prince educated gentleman culture european eur... | 53 |
| 163774 | thought man saw year ago occasion went free ha... | 53 |
selected_index = 76
raw_data.text.values[selected_index]
'It was the morning of the twentieth day. At noon we would reach Carson City, the capital of Nevada Territory. We were not glad, but sorry.'
lowered_data.text.values[selected_index]
'it was the morning of the twentieth day. at noon we would reach carson city, the capital of nevada territory. we were not glad, but sorry.'
default_data.text.values[selected_index]
'morning twentieth day noon reach carson city capital nevada territory glad sorry'
lowerinterpunction_data.text.values[selected_index]
'it was the morning of the twentieth day at noon we would reach carson city the capital of nevada territory we were not glad but sorry'
max_min_mean = vis.create_max_min_mean_len([
("RAW", raw_data),
("DEFAULT", default_data),
("LOWER", lowered_data),
("LOWERINTEPUNCTION", lowerinterpunction_data)
])
vis.show_max(max_min_mean)
Zde můžeme pozorovat, že default předzpracování zmenší délku věty na polovinu. Modely nám z tohoto důvodu poběží rychleji a přepokládáme, že tímto předzpracování jsme zároveň nechali nejpodstatnější informace v textových datech.
Toto ale nemusí být pravdou a ověření najdeme ve výsledcích.
vis.show_mean(max_min_mean)
vis.show_min(max_min_mean)
vis.seq_dist(raw_data)
vis.seq_dist(default_data)
vis.seq_dist(lowered_data)
vis.seq_dist(lowerinterpunction_data)
raw_clouds = vis.generate_wordclouds(raw_data)
vis.show_wordcloud(raw_clouds['all'])
<matplotlib.image.AxesImage at 0x7fe86c20e1f0>
default_clouds = vis.generate_wordclouds(default_data)
vis.show_wordcloud(default_clouds['all'])
<matplotlib.image.AxesImage at 0x7fe7c009d3a0>
lower_clouds = vis.generate_wordclouds(lowered_data)
vis.show_wordcloud(lower_clouds['all'])
<matplotlib.image.AxesImage at 0x7fe8b4e32730>
lowerinterpunction_clouds = vis.generate_wordclouds(lowerinterpunction_data)
vis.show_wordcloud(lowerinterpunction_clouds['all'])
<matplotlib.image.AxesImage at 0x7fe7645de3d0>
vis.show_wordcloud(lowerinterpunction_clouds[53])
<matplotlib.image.AxesImage at 0x7fe86c229b50>
vis.show_wordcloud(lowerinterpunction_clouds[8659])
<matplotlib.image.AxesImage at 0x7fe7c0758a90>
raw_top_words = vis.generate_top_words(raw_data)
raw_top_words.keys()
dict_keys([53, 761, 1285, 1800, 8659, 'all'])
raw_top_words[53]
Zde můžeme pozorovat nejzajímavější rozdíl mezi četností jednotlivých slov. Každý z autorů většinou mívá jiná nejvíce používaná slova. Toto může být klíčové pro správnou predikci daného autora.
default_top_words = vis.generate_top_words(default_data)
default_top_words[53]
default_top_words[8659]
default_top_words[1800]
lower_top_words = vis.generate_top_words(lowered_data)
lower_top_words[53]
lowerinterpunction_top_words = vis.generate_top_words(lowerinterpunction_data)
lowerinterpunction_top_words[53]
Jak již jsme zmínili výše, datová sada byla upravena tak, ať pracujeme s vyváženou. Tudíž každý autor má obsaženo stejný počet záznamů. Z toho důvodu bude vybraná metrika přesnost (accuracy).
Tato metrika nám poskytne procentuální hodnotu, která říká jak přesně jsme schopni predikovat, že tento text napsal tento autor.
ALL_KEYS = ['RAW', "LOWER", "DEFAULT", "LOWER_I"]
def load_5():
return {
"RAW": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Raw),
NORMALIZE_LABEL
),
"LOWER": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Lowercase
),
NORMALIZE_LABEL
),
"DEFAULT": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.Default),
NORMALIZE_LABEL
),
"LOWER_I": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_53(),
None,
PreprocessingType.CaseInterpunction),
NORMALIZE_LABEL
)
}
def load_15():
return {
"RAW": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_153(),
None,
PreprocessingType.Raw),
NORMALIZE_LABEL
),
"LOWER": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_153(),
None,
PreprocessingType.Lowercase
),
NORMALIZE_LABEL
),
"DEFAULT": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_153(),
None,
PreprocessingType.Default),
NORMALIZE_LABEL
),
"LOWER_I": normalize_dataframe_to_size(
load_dataset_from_path_with_normalization(
get_load_path_153(),
None,
PreprocessingType.CaseInterpunction),
NORMALIZE_LABEL
)
}
data_5 = load_5()
Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Raw Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Lowercase Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Default Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type CaseInterpunction
data_15 = load_15()
Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Raw Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Lowercase Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type Default Creating shorting method with min = 3 Creating lemma method with instance <WordNetLemmatizer> Specified type CaseInterpunction
data = {
"15": data_15,
"5": data_5
}
data[]
{'5': {'RAW': text label
182816 Thrice a deep breath the knight relieved did d... 761
318651 But they were not cool deeps by any means, for... 53
285365 retorted Klea. "One of her escort threw them t... 1800
322281 The prince is an educated gentleman. His cultu... 53
163774 I thought: Now, this is the man whom I saw twe... 53
... ... ...
52919 They felt that had they done so, they would na... 8659
168164 ***** To W. D. Howells, in America: ... 53
51857 I daresay I shall not succeed at first, but th... 8659
87638 cried Bill. "We have gained an inch, and in an... 8659
259728 But that good man forgot not, even over the wi... 1800
[75000 rows x 2 columns],
'LOWER': text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means, for... 53
285365 retorted klea. "one of her escort threw them t... 1800
322281 the prince is an educated gentleman. his cultu... 53
163774 i thought: now, this is the man whom i saw twe... 53
... ... ...
52919 they felt that had they done so, they would na... 8659
168164 ***** to w. d. howells, in america: kaltenleut... 53
51857 i daresay i shall not succeed at first, but th... 8659
87638 cried bill. "we have gained an inch, and in an... 8659
259728 but that good man forgot not, even over the wi... 1800
[75000 rows x 2 columns],
'DEFAULT': text label
182816 thrice deep breath knight relieved draw fair v... 761
318651 cool deep mean sun ray weltering hot little br... 53
285365 retorted klea escort threw drop subject pray w... 1800
322281 prince educated gentleman culture european eur... 53
163774 thought man saw year ago occasion went free ha... 53
... ... ...
52919 felt naturally accused influenced vindictive f... 8659
168164 howells america kaltenleutgeben bei wien aug d... 53
51857 daresay shall succeed like trying piece open g... 8659
87638 cried gained inch minute shall gained inch hurrah 8659
259728 good man forgot wine jar pleasure folk albeit ... 1800
[75000 rows x 2 columns],
'LOWER_I': text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means for ... 53
285365 retorted klea one of her escort threw them to ... 1800
322281 the prince is an educated gentleman his cultur... 53
163774 i thought now this is the man whom i saw twent... 53
... ... ...
52919 they felt that had they done so they would nat... 8659
168164 to w d howells in america kaltenleutgeben bei ... 53
51857 i daresay i shall not succeed at first but the... 8659
87638 cried bill we have gained an inch and in anoth... 8659
259728 but that good man forgot not even over the win... 1800
[75000 rows x 2 columns]}}
data['5']
{'RAW': text label
182816 Thrice a deep breath the knight relieved did d... 761
318651 But they were not cool deeps by any means, for... 53
285365 retorted Klea. "One of her escort threw them t... 1800
322281 The prince is an educated gentleman. His cultu... 53
163774 I thought: Now, this is the man whom I saw twe... 53
... ... ...
52919 They felt that had they done so, they would na... 8659
168164 ***** To W. D. Howells, in America: ... 53
51857 I daresay I shall not succeed at first, but th... 8659
87638 cried Bill. "We have gained an inch, and in an... 8659
259728 But that good man forgot not, even over the wi... 1800
[75000 rows x 2 columns],
'LOWER': text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means, for... 53
285365 retorted klea. "one of her escort threw them t... 1800
322281 the prince is an educated gentleman. his cultu... 53
163774 i thought: now, this is the man whom i saw twe... 53
... ... ...
52919 they felt that had they done so, they would na... 8659
168164 ***** to w. d. howells, in america: kaltenleut... 53
51857 i daresay i shall not succeed at first, but th... 8659
87638 cried bill. "we have gained an inch, and in an... 8659
259728 but that good man forgot not, even over the wi... 1800
[75000 rows x 2 columns],
'DEFAULT': text label
182816 thrice deep breath knight relieved draw fair v... 761
318651 cool deep mean sun ray weltering hot little br... 53
285365 retorted klea escort threw drop subject pray w... 1800
322281 prince educated gentleman culture european eur... 53
163774 thought man saw year ago occasion went free ha... 53
... ... ...
52919 felt naturally accused influenced vindictive f... 8659
168164 howells america kaltenleutgeben bei wien aug d... 53
51857 daresay shall succeed like trying piece open g... 8659
87638 cried gained inch minute shall gained inch hurrah 8659
259728 good man forgot wine jar pleasure folk albeit ... 1800
[75000 rows x 2 columns],
'LOWER_I': text label
182816 thrice a deep breath the knight relieved did d... 761
318651 but they were not cool deeps by any means for ... 53
285365 retorted klea one of her escort threw them to ... 1800
322281 the prince is an educated gentleman his cultur... 53
163774 i thought now this is the man whom i saw twent... 53
... ... ...
52919 they felt that had they done so they would nat... 8659
168164 to w d howells in america kaltenleutgeben bei ... 53
51857 i daresay i shall not succeed at first but the... 8659
87638 cried bill we have gained an inch and in anoth... 8659
259728 but that good man forgot not even over the win... 1800
[75000 rows x 2 columns]}
BLANK = '-'
class Fields(Enum):
ModelName = 'ModelName'
BatchSize = 'BatchSize'
Optimizer = 'Optimizer'
LR = 'LR'
Epochs = 'Epochs'
EmbeddingSize = 'EmbeddingSize'
Time = 'Time'
Accuracy = 'Accuracy'
Hits = 'Hits'
Miss = 'Miss'
Key = 'Key'
SeqLen = 'SeqLen'
VocabSize = 'VocabSize'
TrainableEmbedding = 'TrainableEmbedding'
ConfMatrix = "ConfMatrix"
ModelType = "Type"
TransformerName = "TransformerName"
NumberOfAuthors= "NumberOfAuthors"
def create_value(
ModelName=BLANK,
BatchSize=BLANK,
Optimizer=BLANK,
Epochs=BLANK,
EmbeddingSize=BLANK,
Time=BLANK,
Accuracy=BLANK,
LR=BLANK,
Hits=BLANK,
Miss=BLANK,
Key=BLANK,
SeqLen=BLANK,
VocabSize=BLANK,
TrainableEmbedding=BLANK,
ConfMatrix=BLANK,
ModelType=BLANK,
TransformerName=BLANK,
NumberOfAuthors=BLANK
):
return {
Fields.ModelName.value: ModelName,
Fields.BatchSize.value: BatchSize,
Fields.Optimizer.value: Optimizer,
Fields.LR.value: LR,
Fields.Epochs.value: Epochs,
Fields.EmbeddingSize.value: EmbeddingSize,
Fields.Time.value: Time,
Fields.Accuracy.value: Accuracy,
Fields.Hits.value: Hits,
Fields.Miss.value: Miss,
Fields.Key.value: Key,
Fields.SeqLen.value: SeqLen,
Fields.VocabSize.value: VocabSize,
Fields.TrainableEmbedding.value: TrainableEmbedding,
Fields.ConfMatrix.value: ConfMatrix,
Fields.ModelType.value: ModelType,
Fields.TransformerName.value: TransformerName,
Fields.NumberOfAuthors.value: NumberOfAuthors
}
BATCH_SIZE = 32
BATCH_SIZES = [
#32,
64,
#128,
#256
]
LR = 0.001
TRANSFORMER_LR = [
0.001,
5e-5,
# 4e-5,
# 3e-5,
# 2e-5
]
ADAM = tf.keras.optimizers.Adam
RMS = tf.keras.optimizers.RMSprop
OPTIMIZERS = [
ADAM,
RMS
]
EMB_SIZES = [
50,
#100,
#150,
#200,
#250,
300
]
EPOCHS = 10
LOSS = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
METRICS = [tf.keras.metrics.SparseCategoricalAccuracy("accuracy")]
PATIENCE = 3
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=PATIENCE, restore_best_weights=True, mode="auto")
def setup_directory():
path = os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY)
index = len(os.listdir(os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY)))
path = os.path.sep.join([path, str(index)])
if not os.path.exists(path):
os.makedirs(path)
return path
def save_experiment_info(
path,
ModelName=BLANK,
BatchSize=BLANK,
Optimizer=BLANK,
Epochs=BLANK,
EmbeddingSize=BLANK,
Time=BLANK,
Accuracy=BLANK,
LR=BLANK,
Hits=BLANK,
Miss=BLANK,
Key=BLANK,
SeqLen=BLANK,
VocabSize=BLANK,
TrainableEmbedding=BLANK,
ConfMatrix=BLANK,
ModelType=BLANK,
TransformerName=BLANK,
NumberOfAuthors=BLANK
):
val = create_value(
ModelName=ModelName,
BatchSize=BatchSize,
Optimizer=Optimizer,
Epochs=BLANK,
EmbeddingSize=EmbeddingSize,
Time=Time,
Accuracy=Accuracy,
LR=LR,
Hits=Hits,
Miss=Miss,
Key=Key,
SeqLen=SeqLen,
VocabSize=VocabSize,
TrainableEmbedding=TrainableEmbedding,
ConfMatrix=ConfMatrix,
ModelType=ModelType,
TransformerName=TransformerName,
NumberOfAuthors=NumberOfAuthors
)
df = pd.DataFrame.from_dict(val, orient="index")
path = os.path.sep.join([path, SUMMAR])
print(f"Saving to {path}")
df.to_csv(path, sep=';')
return df
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
def run_dense_model(
max_tokens,
output_sequence_length,
number_of_authors,
emb_size,
key,
loss,
optimizer,
metrics,
batch_size,
epochs,
lr
):
MODEL_NAME = "DENSE"
current_path = setup_directory()
current_data = data[str(number_of_authors)][key]
loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
y_test = encoder.transform(y_test)
y_train = encoder.transform(y_train)
y_valid = encoder.transform(y_valid)
train_ds = create_dataset_from_Xy(X_train, y_train)
test_ds = create_dataset_from_Xy(X_test, y_test)
valid_ds = create_dataset_from_Xy(X_valid, y_valid)
vector_layer = tf.keras.layers.TextVectorization(
max_tokens=max_tokens,
output_mode='int',
standardize=None,
output_sequence_length=output_sequence_length,
split='whitespace'
)
vector_layer.adapt(train_ds.map(lambda x, y: x))
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
model.add(vector_layer)
model.add(
tf.keras.layers.Embedding(
max_tokens + 1,
emb_size,
mask_zero = True
)
)
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.4))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(number_of_authors, activation='softmax'))
optimizer = optimizer(learning_rate=lr)
model.compile(
loss=loss,
optimizer=optimizer,
metrics=metrics,
)
history = model.fit(
train_ds.batch(batch_size),
validation_data=valid_ds.batch(1),
epochs=epochs,
callbacks=[
CSVLogger(current_path),
es
]
)
prediction = model.predict(test_ds.batch(1))
y_pred = prediction_to_labels(prediction)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
return save_experiment_info(
current_path,
ModelName=MODEL_NAME,
BatchSize=batch_size,
Optimizer=type(optimizer).__name__,
Epochs=epochs,
EmbeddingSize=emb_size,
Time=BLANK,
Accuracy=accuracy,
LR=lr,
Hits=0,
Miss=0,
Key=key,
SeqLen=output_sequence_length,
VocabSize=max_tokens,
TrainableEmbedding=True,
ConfMatrix=conf_matrix,
ModelType="NORMAL",
TransformerName=BLANK,
NumberOfAuthors=number_of_authors
)
def generate_model_1_experiments():
for embedding_size in EMB_SIZES:
for vocab_size in [10000]:
for author in [5, 15]:
for seq_len in [200, 400]:
for key in ALL_KEYS:
for optimizer in [ADAM]:
for batch_size in BATCH_SIZES:
for epoch in [EPOCHS]:
for lr in [LR]:
yield lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author
len(list(generate_model_1_experiments()))
32
for exp_values in generate_model_1_experiments():
lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author = exp_values
run_dense_model(
max_tokens=vocab_size,
output_sequence_length=seq_len,
number_of_authors=author,
emb_size=embedding_size,
key=key,
loss=LOSS,
optimizer=optimizer,
metrics=METRICS,
batch_size=batch_size,
epochs=epoch,
lr=lr
)
Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 20s 21ms/step - loss: 1.1092 - accuracy: 0.5457 - val_loss: 0.6625 - val_accuracy: 0.7520 - time: 19.6096 Epoch 2/10 897/897 [==============================] - 19s 21ms/step - loss: 0.5454 - accuracy: 0.8149 - val_loss: 0.6588 - val_accuracy: 0.7741 - time: 38.7481 Epoch 3/10 897/897 [==============================] - 19s 21ms/step - loss: 0.3251 - accuracy: 0.8951 - val_loss: 0.7480 - val_accuracy: 0.7776 - time: 57.6086 Epoch 4/10 897/897 [==============================] - 19s 21ms/step - loss: 0.2037 - accuracy: 0.9366 - val_loss: 0.8819 - val_accuracy: 0.7727 - time: 76.6590 Epoch 5/10 897/897 [==============================] - 19s 21ms/step - loss: 0.1362 - accuracy: 0.9587 - val_loss: 1.0665 - val_accuracy: 0.7686 - time: 95.5141 Saving to ./experiments/1/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 20s 22ms/step - loss: 1.1327 - accuracy: 0.5528 - val_loss: 0.6745 - val_accuracy: 0.7509 - time: 20.0904 Epoch 2/10 897/897 [==============================] - 19s 21ms/step - loss: 0.5607 - accuracy: 0.8075 - val_loss: 0.6359 - val_accuracy: 0.7798 - time: 39.3553 Epoch 3/10 897/897 [==============================] - 19s 21ms/step - loss: 0.3235 - accuracy: 0.8943 - val_loss: 0.7423 - val_accuracy: 0.7864 - time: 58.5663 Epoch 4/10 897/897 [==============================] - 19s 21ms/step - loss: 0.1981 - accuracy: 0.9376 - val_loss: 0.9064 - val_accuracy: 0.7813 - time: 77.7837 Epoch 5/10 897/897 [==============================] - 19s 21ms/step - loss: 0.1383 - accuracy: 0.9579 - val_loss: 1.0002 - val_accuracy: 0.7838 - time: 96.9771 Saving to ./experiments/2/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 19s 21ms/step - loss: 1.2255 - accuracy: 0.4703 - val_loss: 0.8191 - val_accuracy: 0.6439 - time: 18.9839 Epoch 2/10 897/897 [==============================] - 18s 20ms/step - loss: 0.6994 - accuracy: 0.7103 - val_loss: 0.6556 - val_accuracy: 0.7667 - time: 37.2729 Epoch 3/10 897/897 [==============================] - 18s 21ms/step - loss: 0.4104 - accuracy: 0.8544 - val_loss: 0.6998 - val_accuracy: 0.7926 - time: 55.8114 Epoch 4/10 897/897 [==============================] - 18s 20ms/step - loss: 0.2596 - accuracy: 0.9115 - val_loss: 0.8018 - val_accuracy: 0.7928 - time: 74.0906 Epoch 5/10 897/897 [==============================] - 18s 20ms/step - loss: 0.1830 - accuracy: 0.9377 - val_loss: 0.9401 - val_accuracy: 0.7890 - time: 92.3428 Saving to ./experiments/3/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 20s 21ms/step - loss: 1.1458 - accuracy: 0.5311 - val_loss: 0.6403 - val_accuracy: 0.7641 - time: 19.8549 Epoch 2/10 897/897 [==============================] - 19s 21ms/step - loss: 0.5118 - accuracy: 0.8236 - val_loss: 0.5438 - val_accuracy: 0.8133 - time: 38.8472 Epoch 3/10 897/897 [==============================] - 19s 21ms/step - loss: 0.2995 - accuracy: 0.9016 - val_loss: 0.6463 - val_accuracy: 0.8118 - time: 58.1010 Epoch 4/10 897/897 [==============================] - 19s 21ms/step - loss: 0.1938 - accuracy: 0.9365 - val_loss: 0.8115 - val_accuracy: 0.8100 - time: 77.2383 Epoch 5/10 897/897 [==============================] - 19s 21ms/step - loss: 0.1460 - accuracy: 0.9526 - val_loss: 0.8223 - val_accuracy: 0.8096 - time: 96.3258 Saving to ./experiments/4/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 24s 26ms/step - loss: 1.2396 - accuracy: 0.4879 - val_loss: 0.7236 - val_accuracy: 0.7349 - time: 24.0312 Epoch 2/10 897/897 [==============================] - 23s 26ms/step - loss: 0.6198 - accuracy: 0.7835 - val_loss: 0.5899 - val_accuracy: 0.7900 - time: 47.3286 Epoch 3/10 897/897 [==============================] - 23s 26ms/step - loss: 0.3762 - accuracy: 0.8778 - val_loss: 0.6444 - val_accuracy: 0.7909 - time: 70.3903 Epoch 4/10 897/897 [==============================] - 23s 26ms/step - loss: 0.2353 - accuracy: 0.9266 - val_loss: 0.7783 - val_accuracy: 0.7904 - time: 93.5067 Epoch 5/10 897/897 [==============================] - 23s 26ms/step - loss: 0.1597 - accuracy: 0.9511 - val_loss: 0.8636 - val_accuracy: 0.7906 - time: 116.8653 Saving to ./experiments/5/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 24s 26ms/step - loss: 1.2563 - accuracy: 0.4734 - val_loss: 0.8248 - val_accuracy: 0.6913 - time: 23.6612 Epoch 2/10 897/897 [==============================] - 23s 26ms/step - loss: 0.6661 - accuracy: 0.7624 - val_loss: 0.6137 - val_accuracy: 0.7837 - time: 46.7763 Epoch 3/10 897/897 [==============================] - 23s 26ms/step - loss: 0.3902 - accuracy: 0.8723 - val_loss: 0.6765 - val_accuracy: 0.7818 - time: 69.9891 Epoch 4/10 897/897 [==============================] - 23s 25ms/step - loss: 0.2441 - accuracy: 0.9226 - val_loss: 0.7264 - val_accuracy: 0.7920 - time: 92.5897 Epoch 5/10 897/897 [==============================] - 22s 25ms/step - loss: 0.1609 - accuracy: 0.9496 - val_loss: 0.8897 - val_accuracy: 0.7873 - time: 114.9780 Saving to ./experiments/6/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 23s 25ms/step - loss: 1.4970 - accuracy: 0.3287 - val_loss: 1.0639 - val_accuracy: 0.5627 - time: 22.8250 Epoch 2/10 897/897 [==============================] - 22s 24ms/step - loss: 0.8271 - accuracy: 0.6617 - val_loss: 0.6806 - val_accuracy: 0.7440 - time: 44.7483 Epoch 3/10 897/897 [==============================] - 22s 25ms/step - loss: 0.4942 - accuracy: 0.8202 - val_loss: 0.6229 - val_accuracy: 0.7909 - time: 66.9495 Epoch 4/10 897/897 [==============================] - 22s 25ms/step - loss: 0.3114 - accuracy: 0.8923 - val_loss: 0.6957 - val_accuracy: 0.8049 - time: 89.0135 Epoch 5/10 897/897 [==============================] - 22s 25ms/step - loss: 0.2108 - accuracy: 0.9272 - val_loss: 0.8779 - val_accuracy: 0.7997 - time: 111.1830 Epoch 6/10 897/897 [==============================] - 22s 25ms/step - loss: 0.1557 - accuracy: 0.9475 - val_loss: 0.9966 - val_accuracy: 0.7965 - time: 133.2216 Saving to ./experiments/7/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 24s 26ms/step - loss: 1.2243 - accuracy: 0.4825 - val_loss: 0.7541 - val_accuracy: 0.7225 - time: 23.6197 Epoch 2/10 897/897 [==============================] - 23s 25ms/step - loss: 0.5908 - accuracy: 0.7896 - val_loss: 0.5499 - val_accuracy: 0.8060 - time: 46.3566 Epoch 3/10 897/897 [==============================] - 22s 25ms/step - loss: 0.3404 - accuracy: 0.8870 - val_loss: 0.5833 - val_accuracy: 0.8177 - time: 68.8512 Epoch 4/10 897/897 [==============================] - 22s 25ms/step - loss: 0.2177 - accuracy: 0.9282 - val_loss: 0.7044 - val_accuracy: 0.8138 - time: 91.2466 Epoch 5/10 897/897 [==============================] - 23s 25ms/step - loss: 0.1563 - accuracy: 0.9484 - val_loss: 0.8454 - val_accuracy: 0.8041 - time: 114.0871 Saving to ./experiments/8/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 58s 21ms/step - loss: 2.0389 - accuracy: 0.3196 - val_loss: 1.6315 - val_accuracy: 0.4372 - time: 58.1988 Epoch 2/10 2690/2690 [==============================] - 58s 22ms/step - loss: 1.5636 - accuracy: 0.4617 - val_loss: 1.4995 - val_accuracy: 0.4932 - time: 116.4055 Epoch 3/10 2690/2690 [==============================] - 58s 22ms/step - loss: 1.3251 - accuracy: 0.5466 - val_loss: 1.4742 - val_accuracy: 0.5173 - time: 174.3456 Epoch 4/10 2690/2690 [==============================] - 57s 21ms/step - loss: 1.1326 - accuracy: 0.6159 - val_loss: 1.5002 - val_accuracy: 0.5404 - time: 231.5611 Epoch 5/10 2690/2690 [==============================] - 57s 21ms/step - loss: 0.9792 - accuracy: 0.6727 - val_loss: 1.5381 - val_accuracy: 0.5526 - time: 288.9038 Epoch 6/10 2690/2690 [==============================] - 58s 22ms/step - loss: 0.8644 - accuracy: 0.7159 - val_loss: 1.6331 - val_accuracy: 0.5532 - time: 346.8530 Saving to ./experiments/9/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 58s 21ms/step - loss: 2.0078 - accuracy: 0.3345 - val_loss: 1.6180 - val_accuracy: 0.4401 - time: 57.7445 Epoch 2/10 2690/2690 [==============================] - 57s 21ms/step - loss: 1.5541 - accuracy: 0.4640 - val_loss: 1.4908 - val_accuracy: 0.4930 - time: 115.1954 Epoch 3/10 2690/2690 [==============================] - 57s 21ms/step - loss: 1.3211 - accuracy: 0.5473 - val_loss: 1.4854 - val_accuracy: 0.5198 - time: 171.8556 Epoch 4/10 2690/2690 [==============================] - 56s 21ms/step - loss: 1.1390 - accuracy: 0.6110 - val_loss: 1.5313 - val_accuracy: 0.5374 - time: 228.2602 Epoch 5/10 2690/2690 [==============================] - 57s 21ms/step - loss: 0.9905 - accuracy: 0.6658 - val_loss: 1.5988 - val_accuracy: 0.5395 - time: 284.9560 Epoch 6/10 2690/2690 [==============================] - 55s 20ms/step - loss: 0.8797 - accuracy: 0.7073 - val_loss: 1.6603 - val_accuracy: 0.5481 - time: 339.6589 Saving to ./experiments/10/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 55s 20ms/step - loss: 2.0945 - accuracy: 0.3059 - val_loss: 1.6696 - val_accuracy: 0.4403 - time: 55.1934 Epoch 2/10 2690/2690 [==============================] - 55s 20ms/step - loss: 1.5408 - accuracy: 0.4838 - val_loss: 1.3899 - val_accuracy: 0.5398 - time: 109.7722 Epoch 3/10 2690/2690 [==============================] - 55s 20ms/step - loss: 1.2604 - accuracy: 0.5822 - val_loss: 1.3462 - val_accuracy: 0.5763 - time: 164.6123 Epoch 4/10 2690/2690 [==============================] - 53s 20ms/step - loss: 1.0757 - accuracy: 0.6473 - val_loss: 1.3511 - val_accuracy: 0.5968 - time: 217.2784 Epoch 5/10 2690/2690 [==============================] - 55s 20ms/step - loss: 0.9294 - accuracy: 0.7023 - val_loss: 1.3809 - val_accuracy: 0.6075 - time: 271.9472 Epoch 6/10 2690/2690 [==============================] - 53s 20ms/step - loss: 0.8217 - accuracy: 0.7410 - val_loss: 1.4647 - val_accuracy: 0.6103 - time: 325.1410 Saving to ./experiments/11/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 58s 21ms/step - loss: 2.0427 - accuracy: 0.3319 - val_loss: 1.6068 - val_accuracy: 0.4555 - time: 57.6614 Epoch 2/10 2690/2690 [==============================] - 57s 21ms/step - loss: 1.4889 - accuracy: 0.4952 - val_loss: 1.3843 - val_accuracy: 0.5312 - time: 114.4740 Epoch 3/10 2690/2690 [==============================] - 57s 21ms/step - loss: 1.2211 - accuracy: 0.5902 - val_loss: 1.3423 - val_accuracy: 0.5706 - time: 171.9404 Epoch 4/10 2690/2690 [==============================] - 56s 21ms/step - loss: 1.0366 - accuracy: 0.6577 - val_loss: 1.3896 - val_accuracy: 0.5866 - time: 228.2280 Epoch 5/10 2690/2690 [==============================] - 56s 21ms/step - loss: 0.8927 - accuracy: 0.7113 - val_loss: 1.4266 - val_accuracy: 0.5997 - time: 284.4357 Epoch 6/10 2690/2690 [==============================] - 56s 21ms/step - loss: 0.7831 - accuracy: 0.7487 - val_loss: 1.4492 - val_accuracy: 0.6127 - time: 340.2862 Saving to ./experiments/12/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 68s 25ms/step - loss: 2.0768 - accuracy: 0.3215 - val_loss: 1.6491 - val_accuracy: 0.4290 - time: 68.1557 Epoch 2/10 2690/2690 [==============================] - 65s 24ms/step - loss: 1.5745 - accuracy: 0.4561 - val_loss: 1.4989 - val_accuracy: 0.4890 - time: 133.4462 Epoch 3/10 2690/2690 [==============================] - 67s 25ms/step - loss: 1.3356 - accuracy: 0.5427 - val_loss: 1.4696 - val_accuracy: 0.5233 - time: 200.5947 Epoch 4/10 2690/2690 [==============================] - 67s 25ms/step - loss: 1.1526 - accuracy: 0.6099 - val_loss: 1.4971 - val_accuracy: 0.5414 - time: 268.0899 Epoch 5/10 2690/2690 [==============================] - 68s 25ms/step - loss: 0.9991 - accuracy: 0.6648 - val_loss: 1.5784 - val_accuracy: 0.5458 - time: 336.0488 Epoch 6/10 2690/2690 [==============================] - 67s 25ms/step - loss: 0.8811 - accuracy: 0.7091 - val_loss: 1.6463 - val_accuracy: 0.5537 - time: 403.1251 Saving to ./experiments/13/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 69s 26ms/step - loss: 2.1037 - accuracy: 0.3074 - val_loss: 1.6977 - val_accuracy: 0.4146 - time: 69.4531 Epoch 2/10 2690/2690 [==============================] - 67s 25ms/step - loss: 1.6014 - accuracy: 0.4520 - val_loss: 1.4982 - val_accuracy: 0.4971 - time: 136.9077 Epoch 3/10 2690/2690 [==============================] - 68s 25ms/step - loss: 1.3412 - accuracy: 0.5446 - val_loss: 1.4765 - val_accuracy: 0.5253 - time: 204.5841 Epoch 4/10 2690/2690 [==============================] - 68s 25ms/step - loss: 1.1489 - accuracy: 0.6122 - val_loss: 1.5000 - val_accuracy: 0.5382 - time: 272.7900 Epoch 5/10 2690/2690 [==============================] - 68s 25ms/step - loss: 0.9979 - accuracy: 0.6655 - val_loss: 1.5711 - val_accuracy: 0.5493 - time: 340.7916 Epoch 6/10 2690/2690 [==============================] - 69s 25ms/step - loss: 0.8836 - accuracy: 0.7058 - val_loss: 1.6070 - val_accuracy: 0.5548 - time: 409.3469 Saving to ./experiments/14/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 65s 24ms/step - loss: 2.1489 - accuracy: 0.2800 - val_loss: 1.8111 - val_accuracy: 0.3535 - time: 64.8290 Epoch 2/10 2690/2690 [==============================] - 66s 25ms/step - loss: 1.7113 - accuracy: 0.3945 - val_loss: 1.5883 - val_accuracy: 0.4578 - time: 130.9426 Epoch 3/10 2690/2690 [==============================] - 65s 24ms/step - loss: 1.4490 - accuracy: 0.4916 - val_loss: 1.4911 - val_accuracy: 0.4995 - time: 195.9402 Epoch 4/10 2690/2690 [==============================] - 65s 24ms/step - loss: 1.2560 - accuracy: 0.5620 - val_loss: 1.4724 - val_accuracy: 0.5236 - time: 260.5758 Epoch 5/10 2690/2690 [==============================] - 65s 24ms/step - loss: 1.1043 - accuracy: 0.6202 - val_loss: 1.4690 - val_accuracy: 0.5491 - time: 326.0909 Epoch 6/10 2690/2690 [==============================] - 66s 24ms/step - loss: 0.9787 - accuracy: 0.6690 - val_loss: 1.4998 - val_accuracy: 0.5604 - time: 391.7768 Epoch 7/10 2690/2690 [==============================] - 66s 24ms/step - loss: 0.8763 - accuracy: 0.7058 - val_loss: 1.5598 - val_accuracy: 0.5729 - time: 457.4595 Epoch 8/10 2690/2690 [==============================] - 65s 24ms/step - loss: 0.8010 - accuracy: 0.7346 - val_loss: 1.5760 - val_accuracy: 0.5733 - time: 522.4126 Saving to ./experiments/15/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 69s 25ms/step - loss: 2.1053 - accuracy: 0.3026 - val_loss: 1.6754 - val_accuracy: 0.4143 - time: 68.5011 Epoch 2/10 2690/2690 [==============================] - 68s 25ms/step - loss: 1.5558 - accuracy: 0.4586 - val_loss: 1.4043 - val_accuracy: 0.5211 - time: 136.1559 Epoch 3/10 2690/2690 [==============================] - 68s 25ms/step - loss: 1.2706 - accuracy: 0.5615 - val_loss: 1.3419 - val_accuracy: 0.5576 - time: 204.3615 Epoch 4/10 2690/2690 [==============================] - 68s 25ms/step - loss: 1.0901 - accuracy: 0.6244 - val_loss: 1.3487 - val_accuracy: 0.5815 - time: 272.1085 Epoch 5/10 2690/2690 [==============================] - 68s 25ms/step - loss: 0.9526 - accuracy: 0.6788 - val_loss: 1.3749 - val_accuracy: 0.5987 - time: 339.7991 Epoch 6/10 2690/2690 [==============================] - 69s 26ms/step - loss: 0.8404 - accuracy: 0.7207 - val_loss: 1.4078 - val_accuracy: 0.6082 - time: 408.7929 Saving to ./experiments/16/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 43s 47ms/step - loss: 1.2061 - accuracy: 0.5213 - val_loss: 0.7380 - val_accuracy: 0.7222 - time: 42.8878 Epoch 2/10 897/897 [==============================] - 41s 46ms/step - loss: 0.5877 - accuracy: 0.7960 - val_loss: 0.6602 - val_accuracy: 0.7658 - time: 84.3094 Epoch 3/10 897/897 [==============================] - 41s 46ms/step - loss: 0.3010 - accuracy: 0.9012 - val_loss: 0.7774 - val_accuracy: 0.7722 - time: 125.6398 Epoch 4/10 897/897 [==============================] - 41s 46ms/step - loss: 0.1749 - accuracy: 0.9461 - val_loss: 0.9172 - val_accuracy: 0.7683 - time: 166.7339 Epoch 5/10 897/897 [==============================] - 41s 46ms/step - loss: 0.1109 - accuracy: 0.9665 - val_loss: 1.0060 - val_accuracy: 0.7639 - time: 207.7955 Saving to ./experiments/17/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 43s 47ms/step - loss: 1.1219 - accuracy: 0.5644 - val_loss: 0.6812 - val_accuracy: 0.7573 - time: 42.9350 Epoch 2/10 897/897 [==============================] - 42s 46ms/step - loss: 0.5310 - accuracy: 0.8199 - val_loss: 0.6584 - val_accuracy: 0.7799 - time: 84.6618 Epoch 3/10 897/897 [==============================] - 42s 46ms/step - loss: 0.2790 - accuracy: 0.9114 - val_loss: 0.7617 - val_accuracy: 0.7885 - time: 126.4006 Epoch 4/10 897/897 [==============================] - 41s 46ms/step - loss: 0.1603 - accuracy: 0.9515 - val_loss: 0.8541 - val_accuracy: 0.7856 - time: 167.3357 Epoch 5/10 897/897 [==============================] - 42s 46ms/step - loss: 0.1132 - accuracy: 0.9660 - val_loss: 0.9667 - val_accuracy: 0.7854 - time: 208.9221 Saving to ./experiments/18/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 41s 45ms/step - loss: 1.2983 - accuracy: 0.4404 - val_loss: 0.8450 - val_accuracy: 0.6599 - time: 41.1660 Epoch 2/10 897/897 [==============================] - 40s 45ms/step - loss: 0.6719 - accuracy: 0.7503 - val_loss: 0.6063 - val_accuracy: 0.7831 - time: 81.3468 Epoch 3/10 897/897 [==============================] - 39s 43ms/step - loss: 0.3585 - accuracy: 0.8751 - val_loss: 0.6537 - val_accuracy: 0.8036 - time: 120.3311 Epoch 4/10 897/897 [==============================] - 40s 45ms/step - loss: 0.2178 - accuracy: 0.9277 - val_loss: 0.7792 - val_accuracy: 0.8030 - time: 160.6190 Epoch 5/10 897/897 [==============================] - 41s 46ms/step - loss: 0.1476 - accuracy: 0.9496 - val_loss: 0.8659 - val_accuracy: 0.8031 - time: 201.4358 Saving to ./experiments/19/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 42s 46ms/step - loss: 1.1131 - accuracy: 0.5664 - val_loss: 0.6342 - val_accuracy: 0.7694 - time: 42.3203 Epoch 2/10 897/897 [==============================] - 41s 45ms/step - loss: 0.4970 - accuracy: 0.8301 - val_loss: 0.5627 - val_accuracy: 0.8014 - time: 83.0308 Epoch 3/10 897/897 [==============================] - 41s 45ms/step - loss: 0.2636 - accuracy: 0.9124 - val_loss: 0.6593 - val_accuracy: 0.8094 - time: 123.8527 Epoch 4/10 897/897 [==============================] - 40s 45ms/step - loss: 0.1684 - accuracy: 0.9451 - val_loss: 0.7868 - val_accuracy: 0.8050 - time: 164.2680 Epoch 5/10 897/897 [==============================] - 41s 46ms/step - loss: 0.1226 - accuracy: 0.9602 - val_loss: 0.8197 - val_accuracy: 0.7995 - time: 205.1059 Saving to ./experiments/20/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 76s 84ms/step - loss: 1.4314 - accuracy: 0.3718 - val_loss: 1.1860 - val_accuracy: 0.4921 - time: 75.8315 Epoch 2/10 897/897 [==============================] - 73s 81ms/step - loss: 0.9765 - accuracy: 0.5907 - val_loss: 0.7914 - val_accuracy: 0.7023 - time: 148.4724 Epoch 3/10 897/897 [==============================] - 73s 82ms/step - loss: 0.5856 - accuracy: 0.7853 - val_loss: 0.6904 - val_accuracy: 0.7548 - time: 221.8757 Epoch 4/10 897/897 [==============================] - 71s 80ms/step - loss: 0.3427 - accuracy: 0.8844 - val_loss: 0.7478 - val_accuracy: 0.7689 - time: 293.3946 Epoch 5/10 897/897 [==============================] - 69s 77ms/step - loss: 0.2151 - accuracy: 0.9300 - val_loss: 0.8492 - val_accuracy: 0.7686 - time: 362.3219 Epoch 6/10 897/897 [==============================] - 71s 79ms/step - loss: 0.1472 - accuracy: 0.9546 - val_loss: 1.0010 - val_accuracy: 0.7689 - time: 433.5334 Saving to ./experiments/21/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 72s 80ms/step - loss: 1.2526 - accuracy: 0.4923 - val_loss: 0.8532 - val_accuracy: 0.6709 - time: 72.2412 Epoch 2/10 897/897 [==============================] - 71s 79ms/step - loss: 0.7181 - accuracy: 0.7276 - val_loss: 0.7140 - val_accuracy: 0.7448 - time: 143.0743 Epoch 3/10 897/897 [==============================] - 72s 80ms/step - loss: 0.4157 - accuracy: 0.8593 - val_loss: 0.6968 - val_accuracy: 0.7798 - time: 214.9911 Epoch 4/10 897/897 [==============================] - 71s 79ms/step - loss: 0.2424 - accuracy: 0.9235 - val_loss: 0.7680 - val_accuracy: 0.7853 - time: 286.2484 Epoch 5/10 897/897 [==============================] - 72s 80ms/step - loss: 0.1611 - accuracy: 0.9519 - val_loss: 0.8226 - val_accuracy: 0.7791 - time: 358.0559 Epoch 6/10 897/897 [==============================] - 71s 79ms/step - loss: 0.1120 - accuracy: 0.9671 - val_loss: 0.9719 - val_accuracy: 0.7857 - time: 428.7859 Saving to ./experiments/22/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 72s 79ms/step - loss: 1.6130 - accuracy: 0.2575 - val_loss: 1.6093 - val_accuracy: 0.1923 - time: 71.8384 Epoch 2/10 897/897 [==============================] - 71s 79ms/step - loss: 1.6099 - accuracy: 0.2005 - val_loss: 1.6093 - val_accuracy: 0.1992 - time: 143.0636 Epoch 3/10 897/897 [==============================] - 72s 81ms/step - loss: 1.6103 - accuracy: 0.2002 - val_loss: 1.6096 - val_accuracy: 0.1920 - time: 215.4408 Epoch 4/10 897/897 [==============================] - 72s 81ms/step - loss: 1.6100 - accuracy: 0.2012 - val_loss: 1.6096 - val_accuracy: 0.1987 - time: 287.8035 Epoch 5/10 897/897 [==============================] - 73s 82ms/step - loss: 1.6105 - accuracy: 0.1998 - val_loss: 1.6096 - val_accuracy: 0.1922 - time: 360.9053 Saving to ./experiments/23/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 73s 81ms/step - loss: 1.6123 - accuracy: 0.1997 - val_loss: 1.6084 - val_accuracy: 0.1922 - time: 72.8000 Epoch 2/10 897/897 [==============================] - 71s 79ms/step - loss: 1.4137 - accuracy: 0.3396 - val_loss: 1.1041 - val_accuracy: 0.5329 - time: 144.0263 Epoch 3/10 897/897 [==============================] - 71s 79ms/step - loss: 0.8928 - accuracy: 0.6406 - val_loss: 0.7459 - val_accuracy: 0.7194 - time: 215.0285 Epoch 4/10 897/897 [==============================] - 72s 80ms/step - loss: 0.5330 - accuracy: 0.8096 - val_loss: 0.6141 - val_accuracy: 0.7989 - time: 286.6653 Epoch 5/10 897/897 [==============================] - 73s 82ms/step - loss: 0.3148 - accuracy: 0.8956 - val_loss: 0.6382 - val_accuracy: 0.8080 - time: 360.2254 Epoch 6/10 897/897 [==============================] - 73s 81ms/step - loss: 0.2174 - accuracy: 0.9292 - val_loss: 0.6781 - val_accuracy: 0.8146 - time: 432.7831 Epoch 7/10 897/897 [==============================] - 71s 79ms/step - loss: 0.1608 - accuracy: 0.9483 - val_loss: 0.7662 - val_accuracy: 0.8127 - time: 503.9513 Saving to ./experiments/24/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 122s 45ms/step - loss: 2.0102 - accuracy: 0.3290 - val_loss: 1.6453 - val_accuracy: 0.4280 - time: 122.0549 Epoch 2/10 2690/2690 [==============================] - 120s 45ms/step - loss: 1.5291 - accuracy: 0.4641 - val_loss: 1.5359 - val_accuracy: 0.4795 - time: 241.9263 Epoch 3/10 2690/2690 [==============================] - 122s 45ms/step - loss: 1.2336 - accuracy: 0.5628 - val_loss: 1.5504 - val_accuracy: 0.5069 - time: 364.0079 Epoch 4/10 2690/2690 [==============================] - 122s 45ms/step - loss: 0.9965 - accuracy: 0.6463 - val_loss: 1.6414 - val_accuracy: 0.5103 - time: 486.2525 Epoch 5/10 2690/2690 [==============================] - 122s 45ms/step - loss: 0.8289 - accuracy: 0.7106 - val_loss: 1.6834 - val_accuracy: 0.5260 - time: 607.9542 Saving to ./experiments/25/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 123s 45ms/step - loss: 2.0808 - accuracy: 0.3102 - val_loss: 1.6946 - val_accuracy: 0.4060 - time: 122.9349 Epoch 2/10 2690/2690 [==============================] - 121s 45ms/step - loss: 1.5866 - accuracy: 0.4425 - val_loss: 1.5601 - val_accuracy: 0.4671 - time: 244.0780 Epoch 3/10 2690/2690 [==============================] - 122s 45ms/step - loss: 1.3025 - accuracy: 0.5354 - val_loss: 1.5680 - val_accuracy: 0.4814 - time: 366.3781 Epoch 4/10 2690/2690 [==============================] - 120s 45ms/step - loss: 1.0726 - accuracy: 0.6133 - val_loss: 1.6285 - val_accuracy: 0.4982 - time: 486.3974 Epoch 5/10 2690/2690 [==============================] - 122s 45ms/step - loss: 0.9043 - accuracy: 0.6749 - val_loss: 1.6716 - val_accuracy: 0.5089 - time: 608.0671 Saving to ./experiments/26/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 125s 46ms/step - loss: 2.2296 - accuracy: 0.2580 - val_loss: 1.8096 - val_accuracy: 0.3592 - time: 124.4856 Epoch 2/10 2690/2690 [==============================] - 123s 46ms/step - loss: 1.6584 - accuracy: 0.4205 - val_loss: 1.5306 - val_accuracy: 0.4859 - time: 248.0259 Epoch 3/10 2690/2690 [==============================] - 121s 45ms/step - loss: 1.3361 - accuracy: 0.5402 - val_loss: 1.4402 - val_accuracy: 0.5341 - time: 368.8101 Epoch 4/10 2690/2690 [==============================] - 122s 46ms/step - loss: 1.0787 - accuracy: 0.6410 - val_loss: 1.4426 - val_accuracy: 0.5686 - time: 491.3656 Epoch 5/10 2690/2690 [==============================] - 117s 44ms/step - loss: 0.8880 - accuracy: 0.7134 - val_loss: 1.4858 - val_accuracy: 0.5840 - time: 608.8162 Epoch 6/10 2690/2690 [==============================] - 120s 45ms/step - loss: 0.7560 - accuracy: 0.7655 - val_loss: 1.5254 - val_accuracy: 0.5940 - time: 729.1473 Saving to ./experiments/27/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 123s 45ms/step - loss: 1.9839 - accuracy: 0.3485 - val_loss: 1.5318 - val_accuracy: 0.4723 - time: 122.9904 Epoch 2/10 2690/2690 [==============================] - 122s 46ms/step - loss: 1.4141 - accuracy: 0.5101 - val_loss: 1.3959 - val_accuracy: 0.5345 - time: 245.4391 Epoch 3/10 2690/2690 [==============================] - 122s 45ms/step - loss: 1.1213 - accuracy: 0.6072 - val_loss: 1.3805 - val_accuracy: 0.5586 - time: 367.5131 Epoch 4/10 2690/2690 [==============================] - 121s 45ms/step - loss: 0.9080 - accuracy: 0.6855 - val_loss: 1.4342 - val_accuracy: 0.5822 - time: 488.1874 Epoch 5/10 2690/2690 [==============================] - 124s 46ms/step - loss: 0.7550 - accuracy: 0.7454 - val_loss: 1.4890 - val_accuracy: 0.5930 - time: 611.9943 Epoch 6/10 2690/2690 [==============================] - 122s 45ms/step - loss: 0.6421 - accuracy: 0.7907 - val_loss: 1.5445 - val_accuracy: 0.6021 - time: 733.5118 Saving to ./experiments/28/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 220s 81ms/step - loss: 2.2987 - accuracy: 0.2511 - val_loss: 2.0420 - val_accuracy: 0.3019 - time: 219.6768 Epoch 2/10 2690/2690 [==============================] - 214s 80ms/step - loss: 1.9066 - accuracy: 0.3432 - val_loss: 1.7355 - val_accuracy: 0.4036 - time: 433.9882 Epoch 3/10 2690/2690 [==============================] - 217s 81ms/step - loss: 1.5962 - accuracy: 0.4386 - val_loss: 1.6540 - val_accuracy: 0.4393 - time: 650.6270 Epoch 4/10 2690/2690 [==============================] - 221s 82ms/step - loss: 1.3669 - accuracy: 0.5185 - val_loss: 1.6590 - val_accuracy: 0.4497 - time: 871.8096 Epoch 5/10 2690/2690 [==============================] - 206s 77ms/step - loss: 1.1789 - accuracy: 0.5871 - val_loss: 1.6885 - val_accuracy: 0.4685 - time: 1078.1029 Epoch 6/10 2690/2690 [==============================] - 208s 77ms/step - loss: 1.0267 - accuracy: 0.6456 - val_loss: 1.7340 - val_accuracy: 0.4851 - time: 1286.4270 Saving to ./experiments/29/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 212s 79ms/step - loss: 2.2158 - accuracy: 0.2619 - val_loss: 1.9355 - val_accuracy: 0.3273 - time: 212.3794 Epoch 2/10 2690/2690 [==============================] - 205s 76ms/step - loss: 1.8288 - accuracy: 0.3532 - val_loss: 1.7168 - val_accuracy: 0.3944 - time: 417.8752 Epoch 3/10 2690/2690 [==============================] - 210s 78ms/step - loss: 1.5736 - accuracy: 0.4344 - val_loss: 1.6661 - val_accuracy: 0.4275 - time: 627.8725 Epoch 4/10 2690/2690 [==============================] - 212s 79ms/step - loss: 1.3595 - accuracy: 0.5067 - val_loss: 1.6519 - val_accuracy: 0.4645 - time: 839.9593 Epoch 5/10 2690/2690 [==============================] - 207s 77ms/step - loss: 1.1728 - accuracy: 0.5773 - val_loss: 1.6700 - val_accuracy: 0.4849 - time: 1047.2926 Epoch 6/10 2690/2690 [==============================] - 214s 80ms/step - loss: 1.0239 - accuracy: 0.6329 - val_loss: 1.7395 - val_accuracy: 0.4991 - time: 1261.5815 Epoch 7/10 2690/2690 [==============================] - 214s 80ms/step - loss: 0.9046 - accuracy: 0.6801 - val_loss: 1.8413 - val_accuracy: 0.4966 - time: 1475.8997 Saving to ./experiments/30/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 214s 79ms/step - loss: 2.7099 - accuracy: 0.1087 - val_loss: 2.7085 - val_accuracy: 0.0659 - time: 213.5496 Epoch 2/10 2690/2690 [==============================] - 215s 80ms/step - loss: 2.7083 - accuracy: 0.0674 - val_loss: 2.7085 - val_accuracy: 0.0658 - time: 428.6687 Epoch 3/10 2690/2690 [==============================] - 218s 81ms/step - loss: 2.7085 - accuracy: 0.0663 - val_loss: 2.7085 - val_accuracy: 0.0649 - time: 646.2720 Epoch 4/10 2690/2690 [==============================] - 223s 83ms/step - loss: 2.7084 - accuracy: 0.0661 - val_loss: 2.7084 - val_accuracy: 0.0649 - time: 869.2077 Epoch 5/10 2690/2690 [==============================] - 218s 81ms/step - loss: 2.7082 - accuracy: 0.0661 - val_loss: 2.7081 - val_accuracy: 0.0650 - time: 1086.8918 Epoch 6/10 2690/2690 [==============================] - 218s 81ms/step - loss: 2.4676 - accuracy: 0.1570 - val_loss: 2.2634 - val_accuracy: 0.2210 - time: 1304.6082 Epoch 7/10 2690/2690 [==============================] - 222s 82ms/step - loss: 2.1338 - accuracy: 0.2583 - val_loss: 2.0906 - val_accuracy: 0.2700 - time: 1526.2462 Epoch 8/10 2690/2690 [==============================] - 213s 79ms/step - loss: 1.9281 - accuracy: 0.3258 - val_loss: 1.9309 - val_accuracy: 0.3318 - time: 1739.2393 Epoch 9/10 2690/2690 [==============================] - 225s 84ms/step - loss: 1.6518 - accuracy: 0.4234 - val_loss: 1.7446 - val_accuracy: 0.4059 - time: 1964.0190 Epoch 10/10 2690/2690 [==============================] - 219s 82ms/step - loss: 1.4152 - accuracy: 0.5107 - val_loss: 1.7144 - val_accuracy: 0.4340 - time: 2183.3675 Saving to ./experiments/31/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 229s 85ms/step - loss: 2.1107 - accuracy: 0.2927 - val_loss: 1.7549 - val_accuracy: 0.3950 - time: 229.2097 Epoch 2/10 2690/2690 [==============================] - 212s 79ms/step - loss: 1.6085 - accuracy: 0.4436 - val_loss: 1.4837 - val_accuracy: 0.5008 - time: 441.2030 Epoch 3/10 2690/2690 [==============================] - 215s 80ms/step - loss: 1.2858 - accuracy: 0.5538 - val_loss: 1.4181 - val_accuracy: 0.5407 - time: 655.9851 Epoch 4/10 2690/2690 [==============================] - 217s 81ms/step - loss: 1.0564 - accuracy: 0.6399 - val_loss: 1.4044 - val_accuracy: 0.5668 - time: 873.4003 Epoch 5/10 2690/2690 [==============================] - 219s 81ms/step - loss: 0.8795 - accuracy: 0.7080 - val_loss: 1.4369 - val_accuracy: 0.5816 - time: 1092.1185 Epoch 6/10 2690/2690 [==============================] - 217s 81ms/step - loss: 0.7448 - accuracy: 0.7617 - val_loss: 1.4501 - val_accuracy: 0.5995 - time: 1309.1906 Epoch 7/10 2690/2690 [==============================] - 221s 82ms/step - loss: 0.6385 - accuracy: 0.8015 - val_loss: 1.4799 - val_accuracy: 0.6026 - time: 1530.5399 Saving to ./experiments/32/sum.csv
def run_rnn_model(
max_tokens,
output_sequence_length,
number_of_authors,
emb_size,
key,
loss,
optimizer,
metrics,
batch_size,
epochs,
lr
):
MODEL_NAME = "Bidirectional GRU"
current_path = setup_directory()
current_data = data[str(number_of_authors)][key]
loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
y_test = encoder.transform(y_test)
y_train = encoder.transform(y_train)
y_valid = encoder.transform(y_valid)
train_ds = create_dataset_from_Xy(X_train, y_train)
test_ds = create_dataset_from_Xy(X_test, y_test)
valid_ds = create_dataset_from_Xy(X_valid, y_valid)
vector_layer = tf.keras.layers.TextVectorization(
max_tokens=max_tokens,
output_mode='int',
standardize=None,
output_sequence_length=output_sequence_length,
split='whitespace'
)
vector_layer.adapt(train_ds.map(lambda x, y: x))
model = tf.keras.Sequential()
model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
model.add(vector_layer)
model.add(
tf.keras.layers.Embedding(
max_tokens + 1,
emb_size,
mask_zero = True
)
)
model.add(tf.keras.layers.Bidirectional(
tf.keras.layers.LSTM(64, activation='relu', return_sequences=True, dropout=0.2, recurrent_dropout=0.2)
))
model.add(
tf.keras.layers.GRU(64, activation='relu', return_sequences=False)
)
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(32, activation='relu'))
model.add(tf.keras.layers.Dropout(rate=0.3))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(number_of_authors, activation='softmax'))
optimizer = optimizer(learning_rate=lr)
model.compile(
loss=loss,
optimizer=optimizer,
metrics=metrics,
)
history = model.fit(
train_ds.batch(batch_size),
validation_data=valid_ds.batch(1),
epochs=epochs,
callbacks=[
CSVLogger(current_path),
es
]
)
prediction = model.predict(test_ds.batch(1))
y_pred = prediction_to_labels(prediction)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
return save_experiment_info(
current_path,
ModelName=MODEL_NAME,
BatchSize=batch_size,
Optimizer=type(optimizer).__name__,
Epochs=epochs,
EmbeddingSize=emb_size,
Time=BLANK,
Accuracy=accuracy,
LR=lr,
Hits=0,
Miss=0,
Key=key,
SeqLen=output_sequence_length,
VocabSize=max_tokens,
TrainableEmbedding=True,
ConfMatrix=conf_matrix,
ModelType="NORMAL",
TransformerName=BLANK,
NumberOfAuthors=number_of_authors
)
def generate_model_rnn_experiments():
for embedding_size in EMB_SIZES:
for vocab_size in [10000]:
for author in [5, 15]:
for seq_len in [300]:
for key in ALL_KEYS:
for optimizer in [ADAM]:
for batch_size in BATCH_SIZES:
for epoch in [EPOCHS]:
for lr in [LR]:
yield lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author
len(list(generate_model_rnn_experiments()))
16
for exp_values in generate_model_rnn_experiments():
lr, embedding_size, vocab_size, seq_len, key, optimizer, batch_size, epoch, author = exp_values
run_rnn_model(
max_tokens=vocab_size,
output_sequence_length=seq_len,
number_of_authors=author,
emb_size=embedding_size,
key=key,
loss=LOSS,
optimizer=optimizer,
metrics=METRICS,
batch_size=batch_size,
epochs=epoch,
lr=lr
)
Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 1093s 1s/step - loss: 6687085056.0000 - accuracy: 0.3819 - val_loss: nan - val_accuracy: 0.2941 - time: 1092.7117 Epoch 2/10 897/897 [==============================] - 1146s 1s/step - loss: 1187478372352.0000 - accuracy: 0.4749 - val_loss: 1.0665 - val_accuracy: 0.5434 - time: 2239.0744 Epoch 3/10 897/897 [==============================] - 1082s 1s/step - loss: 1.3208 - accuracy: 0.5978 - val_loss: 0.9181 - val_accuracy: 0.6209 - time: 3320.7732 Epoch 4/10 897/897 [==============================] - 1050s 1s/step - loss: 5227681792.0000 - accuracy: 0.6492 - val_loss: 0.8942 - val_accuracy: 0.6402 - time: 4371.0521 Epoch 5/10 897/897 [==============================] - 1045s 1s/step - loss: 0.8841 - accuracy: 0.6834 - val_loss: 0.8578 - val_accuracy: 0.6551 - time: 5416.5511 Epoch 6/10 897/897 [==============================] - 1078s 1s/step - loss: 0.7737 - accuracy: 0.6969 - val_loss: 0.8511 - val_accuracy: 0.6684 - time: 6494.4732 Epoch 7/10 897/897 [==============================] - 1319s 1s/step - loss: 1.5510 - accuracy: 0.7337 - val_loss: 0.8364 - val_accuracy: 0.6836 - time: 7813.6105 Epoch 8/10 897/897 [==============================] - 1404s 2s/step - loss: 0.6452 - accuracy: 0.7529 - val_loss: 0.8382 - val_accuracy: 0.6761 - time: 9240.8331 Epoch 9/10 897/897 [==============================] - 1337s 1s/step - loss: 0.6104 - accuracy: 0.7729 - val_loss: 0.8286 - val_accuracy: 0.6885 - time: 10577.5128 Epoch 10/10 897/897 [==============================] - 1290s 1s/step - loss: 0.5759 - accuracy: 0.7884 - val_loss: 0.8020 - val_accuracy: 0.7101 - time: 11867.2054 Saving to ./experiments/32/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 1369s 2s/step - loss: 2406436.7500 - accuracy: 0.3964 - val_loss: 1.1757 - val_accuracy: 0.4856 - time: 1368.7052 Epoch 2/10 897/897 [==============================] - 1432s 2s/step - loss: 25465.0742 - accuracy: 0.5265 - val_loss: 1.1074 - val_accuracy: 0.5081 - time: 2801.0649 Epoch 3/10 897/897 [==============================] - 1506s 2s/step - loss: nan - accuracy: 0.3837 - val_loss: nan - val_accuracy: 0.1922 - time: 4307.1053 Epoch 4/10 897/897 [==============================] - 1382s 2s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 5688.8658 Epoch 5/10 897/897 [==============================] - 1845s 2s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 7534.0251 Saving to ./experiments/33/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 1795s 2s/step - loss: 1.3198 - accuracy: 0.3473 - val_loss: 1.1397 - val_accuracy: 0.4634 - time: 1794.7318 Epoch 2/10 897/897 [==============================] - 1556s 2s/step - loss: 183.8856 - accuracy: 0.5911 - val_loss: 0.7558 - val_accuracy: 0.7106 - time: 3350.3680 Epoch 3/10 897/897 [==============================] - 1191s 1s/step - loss: 0.6741 - accuracy: 0.7487 - val_loss: 0.7135 - val_accuracy: 0.7355 - time: 4541.5603 Epoch 4/10 897/897 [==============================] - 1185s 1s/step - loss: 0.5505 - accuracy: 0.7981 - val_loss: 202.5703 - val_accuracy: 0.7536 - time: 5726.7569 Epoch 5/10 897/897 [==============================] - 1184s 1s/step - loss: nan - accuracy: 0.2459 - val_loss: nan - val_accuracy: 0.1922 - time: 6911.1217 Epoch 6/10 897/897 [==============================] - 1180s 1s/step - loss: nan - accuracy: 0.2015 - val_loss: nan - val_accuracy: 0.1922 - time: 8091.5012 Saving to ./experiments/34/sum.csv Train (57375,) Valid (6375,) Test (11250,) Epoch 1/10 897/897 [==============================] - 1202s 1s/step - loss: 278055.2812 - accuracy: 0.3279 - val_loss: 1.2133 - val_accuracy: 0.4411 - time: 1202.2194 Epoch 2/10 897/897 [==============================] - 1192s 1s/step - loss: 1.4219 - accuracy: 0.4619 - val_loss: 1.1394 - val_accuracy: 0.4924 - time: 2394.1133 Epoch 3/10 897/897 [==============================] - 1187s 1s/step - loss: 31504.0977 - accuracy: 0.5164 - val_loss: 1.1058 - val_accuracy: 0.5269 - time: 3581.0423 Epoch 4/10 897/897 [==============================] - 1184s 1s/step - loss: 1.0101 - accuracy: 0.5739 - val_loss: 1.0839 - val_accuracy: 0.5523 - time: 4764.7697 Epoch 5/10 897/897 [==============================] - 1191s 1s/step - loss: 0.9372 - accuracy: 0.6264 - val_loss: 0.8990 - val_accuracy: 0.6436 - time: 5955.9341 Epoch 6/10 897/897 [==============================] - 1185s 1s/step - loss: 0.7581 - accuracy: 0.6962 - val_loss: 0.8012 - val_accuracy: 0.6907 - time: 7140.5361 Epoch 7/10 897/897 [==============================] - 1188s 1s/step - loss: 0.6762 - accuracy: 0.7385 - val_loss: 0.7789 - val_accuracy: 0.7209 - time: 8328.5538 Epoch 8/10 897/897 [==============================] - 1192s 1s/step - loss: 0.5939 - accuracy: 0.7797 - val_loss: 0.7381 - val_accuracy: 0.7454 - time: 9520.1441 Epoch 9/10 897/897 [==============================] - 1187s 1s/step - loss: 666630.4375 - accuracy: 0.7840 - val_loss: 0.7031 - val_accuracy: 0.7437 - time: 10707.5617 Epoch 10/10 897/897 [==============================] - 1185s 1s/step - loss: 0.5122 - accuracy: 0.8174 - val_loss: 0.7276 - val_accuracy: 0.7478 - time: 11892.9372 Saving to ./experiments/35/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 3581s 1s/step - loss: nan - accuracy: 0.0903 - val_loss: nan - val_accuracy: 0.0685 - time: 3581.0230 Epoch 2/10 2690/2690 [==============================] - 3662s 1s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 7243.3928 Epoch 3/10 2690/2690 [==============================] - 6001s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 13244.8738 Saving to ./experiments/36/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 5273s 2s/step - loss: nan - accuracy: 0.0667 - val_loss: nan - val_accuracy: 0.0685 - time: 5272.8774 Epoch 2/10 2690/2690 [==============================] - 4828s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 10100.6725 Epoch 3/10 2690/2690 [==============================] - 4549s 2s/step - loss: nan - accuracy: 0.0661 - val_loss: nan - val_accuracy: 0.0685 - time: 14649.6744 Saving to ./experiments/37/sum.csv Train (172125,) Valid (19125,) Test (33750,) Epoch 1/10 2690/2690 [==============================] - 4745s 2s/step - loss: 143974.9844 - accuracy: 0.1826 - val_loss: 2.1231 - val_accuracy: 0.2460 - time: 4745.0529 Epoch 2/10 1493/2690 [===============>..............] - ETA: 17:04 - loss: 2.0408 - accuracy: 0.2720
class TransformerName(Enum):
DistilBertBaseUncased = "distilbert-base-uncased"
BertBaseUncased = "bert-base-uncased"
ElectraSmall = "google/electra-small-discriminator"
Blank = BLANK_DESCRIPTION
from transformers import TFAutoModel
from transformers import AutoTokenizer
def tokenize(sentences, tokenizer, max_length, padding='max_length'):
return tokenizer(
sentences,
truncation=True,
padding=padding,
max_length=max_length,
return_tensors="tf"
)
def run_transformer_model(
transformer_name,
output_sequence_length,
number_of_authors,
key,
loss,
optimizer,
metrics,
batch_size,
epochs,
lr
):
MODEL_NAME = "Transformer"
current_path = setup_directory()
tokenizer = AutoTokenizer.from_pretrained(transformer_name)
current_data = data[str(number_of_authors)][key]
loader = get_load_path_53 if number_of_authors == 5 else get_load_path_153
encoder = create_encoder_from_path(loader(AUTHORS_FILENAME))
X_train, X_valid, X_test, y_train, y_valid, y_test = split_dataframe_to_train_test_valid(current_data)
y_test = encoder.transform(y_test)
y_train = encoder.transform(y_train)
y_valid = encoder.transform(y_valid)
train_ds = tf.data.Dataset.from_tensor_slices((
dict(tokenize(list(X_train), tokenizer, output_sequence_length)),
y_train
)).batch(batch_size).prefetch(1)
valid_ds = tf.data.Dataset.from_tensor_slices((
dict(tokenize(list(X_valid), tokenizer, output_sequence_length)),
y_valid
)).batch(batch_size).prefetch(1)
test_ds = tf.data.Dataset.from_tensor_slices((
dict(tokenize(list(X_test), tokenizer, output_sequence_length)),
y_test
)).batch(1).prefetch(1)
base = TFAutoModel.from_pretrained(
transformer_name,
)
input_ids = tf.keras.layers.Input(shape=(output_sequence_length,), dtype=tf.int32, name='input_ids')
attention_mask = tf.keras.layers.Input((output_sequence_length,), dtype=tf.int32, name='attention_mask')
#Selection of cls
output = base([input_ids, attention_mask]).last_hidden_state[:, 0, :]
output = tf.keras.layers.Dropout(
rate=0.15,
)(output)
output = tf.keras.layers.Dense(
units=64,
activation='relu',
)(output)
output = tf.keras.layers.BatchNormalization()(output)
output = tf.keras.layers.Dense(
units=64,
activation='relu',
)(output)
output = tf.keras.layers.BatchNormalization()(output)
output_layer = tf.keras.layers.Dense(
units=number_of_authors,
activation='softmax'
)(output)
model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=output_layer)
model.summary()
optimizer = optimizer(learning_rate=lr)
model.compile(
loss=loss,
optimizer=optimizer,
metrics=metrics,
)
history = model.fit(
train_ds,
validation_data=valid_ds,
epochs=epochs,
callbacks=[
CSVLogger(current_path),
es
]
)
prediction = model.predict(test_ds)
y_pred = prediction_to_labels(prediction)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
return save_experiment_info(
current_path,
ModelName=MODEL_NAME,
BatchSize=batch_size,
Optimizer=type(optimizer).__name__,
Epochs=epochs,
EmbeddingSize=BLANK,
Time=BLANK,
Accuracy=accuracy,
LR=lr,
Hits=0,
Miss=0,
Key=key,
SeqLen=output_sequence_length,
VocabSize=BLANK,
TrainableEmbedding=True,
ConfMatrix=conf_matrix,
ModelType="TL",
TransformerName=transformer_name,
NumberOfAuthors=number_of_authors
)
def generate_model_transformer_experiments():
for transformer_name in [TransformerName.DistilBertBaseUncased.value, TransformerName.BertBaseUncased.value, TransformerName.ElectraSmall.value ]:
for author in [5]:
for seq_len in [300]:
for key in ["LOWER_I"]:
for optimizer in [ADAM]:
for batch_size in [128]:
for epoch in [3]:
for lr in TRANSFORMER_LR:
yield (
transformer_name,
seq_len,
author,
key,
LOSS,
optimizer,
METRICS,
batch_size,
epoch,
lr
)
list(generate_model_transformer_experiments())
[('distilbert-base-uncased',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
0.001),
('distilbert-base-uncased',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
5e-05),
('bert-base-uncased',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
0.001),
('bert-base-uncased',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
5e-05),
('google/electra-small-discriminator',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
0.001),
('google/electra-small-discriminator',
300,
5,
'LOWER_I',
<keras.losses.SparseCategoricalCrossentropy at 0x7fe61924cf70>,
keras.optimizer_v2.adam.Adam,
[<keras.metrics.SparseCategoricalAccuracy at 0x7fe319979670>],
128,
3,
5e-05)]
len(list(generate_model_transformer_experiments()))
6
for exp_values in generate_model_transformer_experiments():
run_transformer_model(
*exp_values
)
Train (57375,) Valid (6375,) Test (11250,)
Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_projector', 'vocab_transform', 'activation_13', 'vocab_layer_norm'] - This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.
Model: "model_4"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_ids (InputLayer) [(None, 300)] 0 []
attention_mask (InputLayer) [(None, 300)] 0 []
tf_distil_bert_model_5 (TFDist TFBaseModelOutput(l 66362880 ['input_ids[0][0]',
ilBertModel) ast_hidden_state=(N 'attention_mask[0][0]']
one, 300, 768),
hidden_states=None
, attentions=None)
tf.__operators__.getitem_5 (Sl (None, 768) 0 ['tf_distil_bert_model_5[0][0]']
icingOpLambda)
dropout_309 (Dropout) (None, 768) 0 ['tf.__operators__.getitem_5[0][0
]']
dense_291 (Dense) (None, 64) 49216 ['dropout_309[0][0]']
batch_normalization_11 (BatchN (None, 64) 256 ['dense_291[0][0]']
ormalization)
dense_292 (Dense) (None, 64) 4160 ['batch_normalization_11[0][0]']
batch_normalization_12 (BatchN (None, 64) 256 ['dense_292[0][0]']
ormalization)
dense_293 (Dense) (None, 5) 325 ['batch_normalization_12[0][0]']
==================================================================================================
Total params: 66,417,093
Trainable params: 66,416,837
Non-trainable params: 256
__________________________________________________________________________________________________
Epoch 1/3
449/449 [==============================] - 12732s 28s/step - loss: 1.7213 - accuracy: 0.1676 - val_loss: 1.6598 - val_accuracy: 0.1987 - time: 12731.7159
Epoch 2/3
314/449 [===================>..........] - ETA: 1:01:37 - loss: 1.6307 - accuracy: 0.2010
from os.path import exists as file_exists
filenames = [
SUMMAR,
LOG,
]
class Storage:
def __init__(self):
self.records = []
def reset(self):
self.records = []
def run(self, directory=None):
self.directory = directory
if self.directory is None:
return
create_dataframe(self.directory, self.records)
def get_dataframe(self):
mapped = map(lambda x: x.iloc[1, :].values, self.records)
df = pd.DataFrame(mapped)
new_header = self.records[0].iloc[0, :].values
df.columns = new_header
return df
def create_dataframe(start_directory, storage=None):
return process_directory(start_directory, storage)
def process_directory(directory, storage=None):
is_correct = is_correct_file(directory)
record = None
if is_correct:
if storage is not None:
record = create_record(directory)
if record is not None:
storage.append(record)
for current_directory in os.listdir(directory):
deeper_level = os.path.sep.join([directory, current_directory])
if os.path.isdir(deeper_level):
process_directory(deeper_level, storage)
def is_correct_file(path):
for filename in filenames:
current_path = os.path.sep.join([path, filename])
if os.path.exists(current_path):
return True
return False
def exists(directory, filename):
current_path = os.path.sep.join([directory, filename])
if file_exists(current_path):
return current_path
return None
def create_record(directory):
try:
log = parse_log(directory)
summ = parse_summa(directory)
record = merge_content(
summ,
log
)
return record
except Exception as e:
print(f"Exception in {directory}")
print(f"Exception {e}")
return None
def parse_summa(directory):
path = exists(directory, SUMMAR)
if path is None:
return None
content = pd.read_csv(path, sep=";", header=None)
return content
def parse_log(directory):
path = exists(directory, LOG)
if path is None:
return None
content = pd.read_csv(path, sep=";")
dic = {}
for index in range(content.shape[1]):
key = content.columns[index]
if "Unnamed" not in key:
dic[key] = [content.iloc[:, index].values]
res = pd.DataFrame.from_dict(dic, orient="index").reset_index()
res.columns = [0, 1]
return res
def merge_content(log=pd.DataFrame(), summ=pd.DataFrame()):
concat_df = pd.concat([summ, log])
record = concat_df.T
return record
start_directory = os.path.sep.join(EXPERIMENTS_SAVE_DIRECTORY)
storage = Storage()
storage.run(start_directory)
pd.set_option('display.max_columns', None)
df = storage.get_dataframe()
df.index = list(range(len(df)))
len(df)
53
import math
df['Accuracy'] = list(map(lambda x: round(float(x), 2), df['Accuracy'].values))
df
| loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | [1.1091737747192385, 0.545360803604126, 0.3251... | [0.545690655708313, 0.8149368166923523, 0.8950... | [0.6624863743782043, 0.658841609954834, 0.7480... | [0.7519999742507935, 0.774117648601532, 0.7775... | [19.60962748527527, 38.74808216094971, 57.6085... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | RAW | 200 | 10000 | True | [[1678 158 256 36 85]\n [ 172 1459 199 ... | NORMAL | - | 5 |
| 1 | [1.1327065229415894, 0.5607202053070068, 0.323... | [0.5527529120445251, 0.807494580745697, 0.8943... | [0.67454594373703, 0.6359243392944336, 0.74231... | [0.7509019374847412, 0.779764711856842, 0.7863... | [20.090360403060917, 39.35527753829956, 58.566... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | LOWER | 200 | 10000 | True | [[1488 306 238 45 136]\n [ 85 1690 127 ... | NORMAL | - | 5 |
| 2 | [1.225493311882019, 0.6993909478187561, 0.4104... | [0.4702745079994201, 0.7103093862533569, 0.854... | [0.8190874457359314, 0.6555845737457275, 0.699... | [0.6439215540885925, 0.7667450904846191, 0.792... | [18.983893871307373, 37.27286505699158, 55.811... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1451 183 376 77 126]\n [ 280 1587 93 ... | NORMAL | - | 5 |
| 3 | [1.1457595825195312, 0.5117649435997009, 0.299... | [0.5310588479042053, 0.8236165642738342, 0.901... | [0.6402604579925537, 0.5437538027763367, 0.646... | [0.7640784382820129, 0.8133333325386047, 0.811... | [19.854929447174072, 38.84715795516968, 58.101... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.81 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1670 150 205 51 137]\n [ 198 1715 102 ... | NORMAL | - | 5 |
| 4 | [1.2395521402359009, 0.6198447942733765, 0.376... | [0.4878901839256286, 0.7835119962692261, 0.877... | [0.7236012816429138, 0.5898836255073547, 0.644... | [0.7349019646644592, 0.7899608016014099, 0.790... | [24.031248092651367, 47.32864117622376, 70.390... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.79 | 0 | 0 | RAW | 400 | 10000 | True | [[1691 143 276 33 70]\n [ 141 1608 183 ... | NORMAL | - | 5 |
| 5 | [1.256301760673523, 0.6661115288734436, 0.3901... | [0.4733960926532745, 0.7623529434204102, 0.872... | [0.8248462677001953, 0.6136951446533203, 0.676... | [0.6912941336631775, 0.7836862802505493, 0.781... | [23.661153078079224, 46.77628660202026, 69.989... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | LOWER | 400 | 10000 | True | [[1750 172 193 31 67]\n [ 198 1542 185 ... | NORMAL | - | 5 |
| 6 | [1.4970093965530396, 0.8270591497421265, 0.494... | [0.3286588191986084, 0.6616818904876709, 0.820... | [1.0638797283172607, 0.6806007027626038, 0.622... | [0.562666654586792, 0.7440000176429749, 0.7909... | [22.825047254562374, 44.74828147888184, 66.949... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.79 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[1574 149 312 64 114]\n [ 230 1586 166 ... | NORMAL | - | 5 |
| 7 | [1.224252700805664, 0.5908034443855286, 0.3403... | [0.4824941158294678, 0.7896470427513123, 0.886... | [0.7540565133094788, 0.5498793125152588, 0.583... | [0.7225098013877869, 0.8059607744216919, 0.817... | [23.619723081588745, 46.35664224624634, 68.851... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.80 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1559 158 332 46 118]\n [ 127 1619 214 ... | NORMAL | - | 5 |
| 8 | [2.038879871368408, 1.563637614250183, 1.32510... | [0.3195910453796386, 0.4616528749465942, 0.546... | [1.6314889192581177, 1.4994858503341677, 1.474... | [0.4371764659881592, 0.4931764602661133, 0.517... | [58.19883918762207, 116.40550565719604, 174.34... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | RAW | 200 | 10000 | True | [[1382 246 16 48 47 53 275 73 19... | NORMAL | - | 15 |
| 9 | [2.007798671722412, 1.554110050201416, 1.32113... | [0.3344993591308594, 0.4640406668186188, 0.547... | [1.6179845333099363, 1.4908241033554075, 1.485... | [0.4401045739650726, 0.4930196106433868, 0.519... | [57.74453091621399, 115.19541382789612, 171.85... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | LOWER | 200 | 10000 | True | [[1182 218 46 45 52 60 311 104 10... | NORMAL | - | 15 |
| 10 | [2.0945076942443848, 1.5407600402832031, 1.260... | [0.3058666586875915, 0.4837879538536072, 0.582... | [1.6695539951324463, 1.3899223804473877, 1.346... | [0.4403137266635895, 0.5398169755935669, 0.576... | [55.19337034225464, 109.7722339630127, 164.612... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1521 213 24 13 44 56 146 17 9... | NORMAL | - | 15 |
| 11 | [2.0426511764526367, 1.4889464378356934, 1.221... | [0.331916332244873, 0.4952156841754913, 0.5902... | [1.6068023443222046, 1.3842854499816897, 1.342... | [0.4554771184921264, 0.5311895608901978, 0.570... | [57.66135025024414, 114.47401738166808, 171.94... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1535 153 14 8 31 47 290 53 11... | NORMAL | - | 15 |
| 12 | [2.0767838954925537, 1.5745370388031006, 1.335... | [0.3215215802192688, 0.4560929536819458, 0.542... | [1.649091720581055, 1.498886227607727, 1.46961... | [0.4290196001529693, 0.4890457391738891, 0.523... | [68.15565466880798, 133.44615292549133, 200.59... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | RAW | 400 | 10000 | True | [[1258 179 48 47 35 38 323 37 10... | NORMAL | - | 15 |
| 13 | [2.10368275642395, 1.6013576984405518, 1.34118... | [0.3073516488075256, 0.4520435631275177, 0.544... | [1.697721242904663, 1.4982041120529177, 1.4765... | [0.4145882427692413, 0.4970980286598205, 0.525... | [69.45309066772461, 136.9076645374298, 204.584... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | LOWER | 400 | 10000 | True | [[1307 134 32 37 30 72 318 55 22... | NORMAL | - | 15 |
| 14 | [2.148930549621582, 1.7112888097763062, 1.4490... | [0.2799529433250427, 0.3945388495922088, 0.491... | [1.811055302619934, 1.5882717370986938, 1.4911... | [0.3535163402557373, 0.4577777683734894, 0.499... | [64.82900762557983, 130.9426231384277, 195.940... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.55 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[1439 88 23 22 27 74 353 36 18... | NORMAL | - | 15 |
| 15 | [2.1052839756011963, 1.5557817220687866, 1.270... | [0.3025882244110107, 0.4585562944412231, 0.561... | [1.6754242181777954, 1.404281497001648, 1.3419... | [0.414274513721466, 0.5210980176925659, 0.5575... | [68.50114560127258, 136.1558701992035, 204.361... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.56 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1434 184 20 21 17 45 317 107 14... | NORMAL | - | 15 |
| 16 | [1.2060670852661133, 0.5876715183258057, 0.301... | [0.5213333368301392, 0.7960087060928345, 0.901... | [0.7380061745643616, 0.6602489352226257, 0.777... | [0.722196102142334, 0.7658039331436157, 0.7722... | [42.88776779174805, 84.30938196182251, 125.639... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | RAW | 200 | 10000 | True | [[1499 249 353 35 77]\n [ 109 1561 244 ... | NORMAL | - | 5 |
| 17 | [1.1218571662902832, 0.5309804677963257, 0.279... | [0.5644392371177673, 0.8198866844177246, 0.911... | [0.6812100410461426, 0.6584259867668152, 0.761... | [0.7573333382606506, 0.7799215912818909, 0.788... | [42.935046672821045, 84.6618320941925, 126.400... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | LOWER | 200 | 10000 | True | [[1526 250 228 57 152]\n [ 111 1533 188 ... | NORMAL | - | 5 |
| 18 | [1.2983185052871704, 0.67186439037323, 0.35850... | [0.4403764605522156, 0.7503268122673035, 0.875... | [0.8450406789779663, 0.6062559485435486, 0.653... | [0.6599215865135193, 0.7830588221549988, 0.803... | [41.16602444648743, 81.34677410125732, 120.331... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.79 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1461 291 227 76 158]\n [ 174 1726 94 ... | NORMAL | - | 5 |
| 19 | [1.1131271123886108, 0.4970025420188904, 0.263... | [0.5663529634475708, 0.8301350474357605, 0.912... | [0.634172260761261, 0.5626846551895142, 0.6592... | [0.7694117426872253, 0.8014117479324341, 0.809... | [42.32026171684265, 83.03080439567566, 123.852... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.80 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1508 290 287 32 96]\n [ 77 1770 190 ... | NORMAL | - | 5 |
| 20 | [1.4313589334487915, 0.976482629776001, 0.5856... | [0.3718431293964386, 0.5906579494476318, 0.785... | [1.1860140562057495, 0.7914318442344666, 0.690... | [0.492078423500061, 0.7022745013237, 0.7548235... | [75.83154845237732, 148.4723603725433, 221.875... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | RAW | 400 | 10000 | True | [[1521 245 329 36 82]\n [ 127 1526 201 ... | NORMAL | - | 5 |
| 21 | [1.2526357173919678, 0.7180959582328796, 0.415... | [0.4922509789466858, 0.7275816798210144, 0.859... | [0.8532054424285889, 0.7139995694160461, 0.696... | [0.6709019541740417, 0.7447842955589294, 0.779... | [72.2411801815033, 143.07429265975952, 214.991... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.77 | 0 | 0 | LOWER | 400 | 10000 | True | [[1541 148 435 16 73]\n [ 74 1643 317 ... | NORMAL | - | 5 |
| 22 | [1.612979292869568, 1.6098767518997192, 1.6102... | [0.2574901878833771, 0.2004705816507339, 0.200... | [1.6093101501464844, 1.6092772483825684, 1.609... | [0.1923137307167053, 0.199215680360794, 0.1920... | [71.83844518661499, 143.0635724067688, 215.440... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.21 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[ 3 0 2210 0 0]\n [ 3 0 2200 ... | NORMAL | - | 5 |
| 23 | [1.6123055219650269, 1.4136512279510498, 0.892... | [0.1997019648551941, 0.339607834815979, 0.6405... | [1.6083794832229614, 1.10409414768219, 0.74592... | [0.1921568661928177, 0.5328627228736877, 0.719... | [72.79997491836548, 144.0263090133667, 215.028... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.79 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1617 342 145 20 89]\n [ 121 1746 103 ... | NORMAL | - | 5 |
| 24 | [2.010152578353882, 1.5290956497192385, 1.2335... | [0.3290140032768249, 0.4641103744506836, 0.562... | [1.6453301906585691, 1.5359132289886477, 1.550... | [0.4279738664627075, 0.4795294106006622, 0.506... | [122.05491280555724, 241.92626881599423, 364.0... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.49 | 0 | 0 | RAW | 200 | 10000 | True | [[1071 369 36 61 34 64 271 114 7... | NORMAL | - | 15 |
| 25 | [2.0807976722717285, 1.5866154432296753, 1.302... | [0.3101803958415985, 0.4424575269222259, 0.535... | [1.694566249847412, 1.5601444244384766, 1.5679... | [0.4060130715370178, 0.4670849740505218, 0.481... | [122.93492102622986, 244.07795214653012, 366.3... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.47 | 0 | 0 | LOWER | 200 | 10000 | True | [[1157 287 34 93 38 72 240 96 18... | NORMAL | - | 15 |
| 26 | [2.22964859008789, 1.658429503440857, 1.336077... | [0.257966011762619, 0.4204676747322082, 0.5402... | [1.8096245527267456, 1.5306107997894287, 1.440... | [0.3592156767845154, 0.4859085083007812, 0.534... | [124.48556709289552, 248.0258502960205, 368.81... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.54 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1180 105 23 97 50 109 413 26 7... | NORMAL | - | 15 |
| 27 | [1.9838707447052, 1.4141194820404053, 1.121265... | [0.3484601378440857, 0.5100711584091187, 0.607... | [1.531771898269653, 1.3959068059921265, 1.3805... | [0.4722614288330078, 0.5345359444618225, 0.558... | [122.99041819572447, 245.43910098075867, 367.5... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1257 309 28 8 77 50 398 8 3... | NORMAL | - | 15 |
| 28 | [2.298656940460205, 1.906553149223328, 1.59620... | [0.2510797381401062, 0.3431517779827118, 0.438... | [2.042023181915283, 1.7355191707611084, 1.6539... | [0.3018562197685241, 0.4036078453063965, 0.439... | [219.676766872406, 433.9881939888001, 650.6269... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.44 | 0 | 0 | RAW | 400 | 10000 | True | [[1050 246 42 204 46 94 260 77 4... | NORMAL | - | 15 |
| 29 | [2.215775489807129, 1.8288313150405884, 1.5736... | [0.2619398832321167, 0.3532491028308868, 0.434... | [1.9354660511016848, 1.716828465461731, 1.6661... | [0.327320247888565, 0.3944052159786224, 0.4274... | [212.37938380241397, 417.8751857280731, 627.87... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.47 | 0 | 0 | LOWER | 400 | 10000 | True | [[1142 226 8 64 130 116 306 88 8... | NORMAL | - | 15 |
| 30 | [2.7099225521087646, 2.708319902420044, 2.7084... | [0.1086745113134384, 0.0673812627792358, 0.066... | [2.7085378170013428, 2.708480834960937, 2.7085... | [0.0658823549747467, 0.0657777786254882, 0.064... | [213.5495536327362, 428.6686849594116, 646.272... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.43 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[ 103 125 11 296 155 802 447 58 7... | NORMAL | - | 15 |
| 31 | [6687085056.0, 1187478372352.0, 1.320836067199... | [0.3819360733032226, 0.4748932421207428, 0.597... | [nan, 1.0665013790130615, 0.9180729389190674, ... | [0.294117659330368, 0.5433725714683533, 0.6208... | [1092.711680173874, 2239.074378967285, 3320.77... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.69 | 0 | 0 | RAW | 300 | 10000 | True | [[1428 125 512 34 114]\n [ 50 1417 393 ... | NORMAL | - | 5 |
| 32 | [2406436.75, 25465.07421875, nan, nan, nan] | [0.3964235186576843, 0.5264662504196167, 0.383... | [1.1757404804229736, 1.1074206829071045, nan, ... | [0.485647052526474, 0.5080784559249878, 0.1921... | [1368.7052409648895, 2801.06485581398, 4307.10... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.51 | 0 | 0 | LOWER | 300 | 10000 | True | [[1731 434 4 3 41]\n [ 601 1472 94 ... | NORMAL | - | 5 |
| 33 | [1.3197650909423828, 183.88555908203125, 0.674... | [0.3472784459590912, 0.5911459922790527, 0.748... | [1.1397396326065063, 0.7558140158653259, 0.713... | [0.4633725583553314, 0.7105882167816162, 0.735... | [1794.731845855713, 3350.3680169582367, 4541.5... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.74 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1087 167 730 68 161]\n [ 17 1742 242 ... | NORMAL | - | 5 |
| 34 | [278055.28125, 1.4219011068344116, 31504.09765... | [0.3278745114803314, 0.4618736505508423, 0.516... | [1.2132927179336548, 1.139415979385376, 1.1058... | [0.4410980343818664, 0.4923921525478363, 0.526... | [1202.219447374344, 2394.1133301258087, 3581.0... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.75 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[1397 428 212 64 112]\n [ 143 1533 396 ... | NORMAL | - | 5 |
| 35 | [nan, nan, nan] | [0.0903417393565177, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [3581.0230338573456, 7243.392804384232, 13244.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | RAW | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 |
| 36 | [nan, nan, nan] | [0.0666928067803382, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [5272.877393007278, 10100.672457456589, 14649.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | LOWER | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 |
| 37 | [143974.984375, 1.981084942817688, nan, nan, nan] | [0.1826091557741165, 0.2948496639728546, 0.301... | [2.123115539550781, 1.8301618099212649, nan, n... | [0.2459607869386673, 0.3584313690662384, 0.068... | [4745.052874326706, 9517.231180667875, 14253.2... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.36 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[ 552 57 29 485 1 132 234 68 4... | NORMAL | - | 15 |
| 38 | [nan, nan, nan] | [0.0665464028716087, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [4647.039727687836, 9398.7957239151, 14184.882... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 |
| 39 | [nan, nan, nan] | [0.1690849661827087, 0.2015163451433181, 0.201... | [nan, nan, nan] | [0.1921568661928177, 0.1921568661928177, 0.192... | [1732.056715965271, 3492.265405893326, 5230.93... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.20 | 0 | 0 | RAW | 300 | 10000 | True | [[2213 0 0 0 0]\n [2203 0 0 ... | NORMAL | - | 5 |
| 40 | [nan, nan, nan] | [0.3445647060871124, 0.2015163451433181, 0.201... | [nan, nan, nan] | [0.1921568661928177, 0.1921568661928177, 0.192... | [1745.630146026611, 3480.498616695404, 4916.81... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.20 | 0 | 0 | LOWER | 300 | 10000 | True | [[2213 0 0 0 0]\n [2203 0 0 ... | NORMAL | - | 5 |
| 41 | [9.50131130218506, 0.7876715660095215, 1.23750... | [0.3992627561092376, 0.6950762271881104, 0.808... | [0.9494749307632446, 0.6531278491020203, 0.620... | [0.6191372275352478, 0.7658039331436157, 0.788... | [1764.6829175949097, 3511.476901292801, 5241.2... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.80 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1597 259 167 86 104]\n [ 151 1747 64 ... | NORMAL | - | 5 |
| 42 | [117312464.0, nan, nan, nan] | [0.4742588102817535, 0.4018823504447937, 0.201... | [1.0666706562042236, nan, nan, nan] | [0.5543529391288757, 0.1921568661928177, 0.192... | [1787.1415581703186, 3542.0427582263947, 5307.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.57 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[1077 131 929 3 73]\n [ 132 1246 712 ... | NORMAL | - | 5 |
| 43 | [nan, nan, nan] | [0.1155686303973198, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [5241.320497512817, 10400.483037471771, 15546.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.07 | 0 | 0 | RAW | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 |
| 44 | [4.139690399169922, 31051933696.0, 5946009.0, ... | [0.2126849740743637, 0.3071604967117309, 0.362... | [45.60334777832031, 1.8750523328781128, 1.7882... | [0.2989281117916107, 0.3549803793430328, 0.382... | [4035.43758225441, 7763.776393175125, 11458.88... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER | 300 | 10000 | True | [[1488 176 59 4 126 19 158 42 5... | NORMAL | - | 15 |
| 45 | [3889620.75, 5.671187400817871, 174.6585388183... | [0.2653124034404754, 0.3765461146831512, 0.458... | [74537.96875, 1.6551882028579712, 3.5518321990... | [0.3310849666595459, 0.4476862847805023, 0.468... | [3747.2181475162506, 7466.645300388336, 11197.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.57 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1448 185 41 29 71 61 236 37 5... | NORMAL | - | 15 |
| 46 | [47510450176.0, 10461987.0, 3.044379234313965,... | [0.2487633973360061, 0.3066957294940948, 0.376... | [2.28581562481608e+19, 1.265042553581863e+18, ... | [0.3026405274868011, 0.3697254955768585, 0.433... | [3202.14945268631, 6400.369203567505, 9596.970... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.43 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[ 918 165 76 21 2 162 512 56 9... | NORMAL | - | 15 |
| 47 | [1.721323847770691, 1.6283942461013794, 1.6182... | [0.167633980512619, 0.200139433145523, 0.20036... | [1.6598328351974487, 1.6358952522277832, 1.622... | [0.1987451016902923, 0.1921568661928177, 0.192... | [12731.715883731842, 25516.930138111115, 38309... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[2213 0 0 0 0]\n [2203 0 0 ... | TL | distilbert-base-uncased | 5 |
| 48 | [0.6788761615753174, 0.2942142188549042, 0.161... | [0.6950744986534119, 0.8956339955329895, 0.945... | [0.4689745604991913, 0.3302323520183563, 0.374... | [0.8304314017295837, 0.8854901790618896, 0.882... | [12792.493296146393, 25582.148589611053, 38391... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.88 | 0 | 0 | LOWER_I | 300 | - | True | [[1932 65 99 57 60]\n [ 96 1780 103 ... | TL | distilbert-base-uncased | 5 |
| 49 | [1.720706582069397, 1.62993323802948, 1.619087... | [0.2695686221122741, 0.1993899792432785, 0.199... | [1.612503170967102, 1.6107159852981567, 1.6184... | [0.2062745094299316, 0.2043921500444412, 0.198... | [25539.273556947708, 51105.00569915772, 76684.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | bert-base-uncased | 5 |
| 50 | [0.6129876375198364, 0.2475549280643463, 0.137... | [0.7192941308021545, 0.9139520525932312, 0.954... | [0.3863182961940765, 0.3083285987377167, 0.356... | [0.8680784106254578, 0.8953725695610046, 0.892... | [25577.77806091309, 51058.13135123253, 76546.3... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.89 | 0 | 0 | LOWER_I | 300 | - | True | [[1845 72 102 118 76]\n [ 34 1847 67 ... | TL | bert-base-uncased | 5 |
| 51 | [1.7034351825714111, 1.6270469427108765, 1.617... | [0.2704313695430755, 0.198047935962677, 0.1987... | [1.6337881088256836, 1.6143453121185305, 1.611... | [0.2062745094299316, 0.2043921500444412, 0.198... | [8417.139698982239, 16806.370790719986, 25206.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | google/electra-small-discriminator | 5 |
| 52 | [1.167790174484253, 0.668626070022583, 0.48968... | [0.5210353136062622, 0.760610044002533, 0.8271... | [0.8901610970497131, 0.8246893286705017, 0.576... | [0.6842352747917175, 0.7243921756744385, 0.799... | [8379.385590314865, 16772.64554834366, 25164.1... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.81 | 0 | 0 | LOWER_I | 300 | - | True | [[1545 117 274 131 146]\n [ 58 1393 243 ... | TL | google/electra-small-discriminator | 5 |
df['CalculationTime'] = list(map(lambda x: round(np.sum(x), 2), df.time.values))
Jak již bylo zmíněno datová sada byla vytvořena vlastními silami. Byl využit Projekt Gutenberg, který obsahuje umělecká díla. Z toho projektu pomocí R skriptu byly separovány díla s anglickým textem a zároveň taková, která obsahovala jasně specifikovaného autora. Tento fakt nám umožnil vytvořit finální datovou sadu. Tato datová sada textový řetězec reprezentovaný velikosti o n větách a identifikátor, který reprezentoval autora díla.
S touto datovou sadou bylo pracováno v diplomové práci, přičemž tento jupyter vznikl akorát pro tento projekt. Jupyter byl extra vytvořen s použitím části kódu z diplomové práce.
Datová sada obsahovala identifikátory pro autory, přičemž zde byly využity datové sady o 5 a 15 autorech. Tyto identifikátory byly transformovaný pomocí LabelEncoderu do žádaného prostoru. Respektive čísel 0-5, 0-15. K samotnému určení, kdo umělecké dílo napsal byla využita klasifikace. Chybou proto byla využita SparseCategoricalCrossentropy, abychom nemuseli target vektory transformovat to one hot vektorů.
Vytvořená datová sada neobsahovala stejný počet záznamů pro každého autora, a proto bylo přistoupeno k normalizaci na určitou hodnotu vzhledem k nejslabšímu autorovi. Pro představu zde v projektu byla využita normalizační hodnota 15 000. Takže ke každému autorovi bylo v datové sadě ponecháno 15 000 záznamů, které byly náhodně vybrány ze všech. Díky tohoto přístupu jsme zároveň mohli využít metriku, která bude zmíněna níže.
Jako metrika byla využita přesnost. Toto jsme si mohli dovolit, díky toho že datová sada obsahovala stejný počet záznamů, pro každého autora. Číselná hodnota nám pak říkala, s jakou přesnosti jsme schopni určit autora vzhledem k autorům s kterými jsme pracovali.
V rámci předzpracování dat bylo experimentováno s 4 přístupu:
Surová data.
Data upravena pomocí gensim metody na předzpracování.
Data transformována do malých písmen.
Data transformovaná do malých písmen společně s malým předzpracováním.
Účelem bylo porovnat jaký vliv má předzpracování na datovou sadu.
Aplikací předzpracování vždy ztrácíme určitou informaci. Většinou bývají textová data na tuto operaci náchylná a je nutné provádět předzpracování s rozvahou.
Po předzpracování můžeme extrahovat relevantní informace, zmenšit vstupní vektor záznamu, získat základní tvar slova, a tak využít již předučený číselný vektor a mnoho dalšího.
Bylo provedeno 53 experimentů, které budou zpracovány. Jak již bylo zmíněno experimenty byly provedeny po 4 různých přístupech předzpracování. Zároveň byly využity různé přístupy a to:
V rámci těchto tříd bylo provedeno postupně:
Jednoduchý model a rekurentní model pracoval s různou velikostí vstupu a zároveň experimentoval s velikostí embeddingu, který byl učen jakožto reprezentace jednotlivých tokenů. Zároveň pro tyto modely byly testovány obě množiny s 5 a 15 autory.
Složitější Transformer model byl bohužel zpracován pouze v podobě 5 autorů a jednom druhu předzpracování. Za to byly využity 3 druhy Transformeru, abychom mohli porovnat jaký model si vede nejlépe. K tomuto porovnání byly vybrány modely:
Více specifický popis bude vytvořen u grafů, které budou vizualizovat dosažené výsledky.
dense_df = df[df.ModelName == "DENSE"]
dense_df
| loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | CalculationTime | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | [1.1091737747192385, 0.545360803604126, 0.3251... | [0.545690655708313, 0.8149368166923523, 0.8950... | [0.6624863743782043, 0.658841609954834, 0.7480... | [0.7519999742507935, 0.774117648601532, 0.7775... | [19.60962748527527, 38.74808216094971, 57.6085... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | RAW | 200 | 10000 | True | [[1678 158 256 36 85]\n [ 172 1459 199 ... | NORMAL | - | 5 | 288.14 |
| 1 | [1.1327065229415894, 0.5607202053070068, 0.323... | [0.5527529120445251, 0.807494580745697, 0.8943... | [0.67454594373703, 0.6359243392944336, 0.74231... | [0.7509019374847412, 0.779764711856842, 0.7863... | [20.090360403060917, 39.35527753829956, 58.566... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | LOWER | 200 | 10000 | True | [[1488 306 238 45 136]\n [ 85 1690 127 ... | NORMAL | - | 5 | 292.77 |
| 2 | [1.225493311882019, 0.6993909478187561, 0.4104... | [0.4702745079994201, 0.7103093862533569, 0.854... | [0.8190874457359314, 0.6555845737457275, 0.699... | [0.6439215540885925, 0.7667450904846191, 0.792... | [18.983893871307373, 37.27286505699158, 55.811... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1451 183 376 77 126]\n [ 280 1587 93 ... | NORMAL | - | 5 | 278.50 |
| 3 | [1.1457595825195312, 0.5117649435997009, 0.299... | [0.5310588479042053, 0.8236165642738342, 0.901... | [0.6402604579925537, 0.5437538027763367, 0.646... | [0.7640784382820129, 0.8133333325386047, 0.811... | [19.854929447174072, 38.84715795516968, 58.101... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.81 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1670 150 205 51 137]\n [ 198 1715 102 ... | NORMAL | - | 5 | 290.37 |
| 4 | [1.2395521402359009, 0.6198447942733765, 0.376... | [0.4878901839256286, 0.7835119962692261, 0.877... | [0.7236012816429138, 0.5898836255073547, 0.644... | [0.7349019646644592, 0.7899608016014099, 0.790... | [24.031248092651367, 47.32864117622376, 70.390... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.79 | 0 | 0 | RAW | 400 | 10000 | True | [[1691 143 276 33 70]\n [ 141 1608 183 ... | NORMAL | - | 5 | 352.12 |
| 5 | [1.256301760673523, 0.6661115288734436, 0.3901... | [0.4733960926532745, 0.7623529434204102, 0.872... | [0.8248462677001953, 0.6136951446533203, 0.676... | [0.6912941336631775, 0.7836862802505493, 0.781... | [23.661153078079224, 46.77628660202026, 69.989... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.77 | 0 | 0 | LOWER | 400 | 10000 | True | [[1750 172 193 31 67]\n [ 198 1542 185 ... | NORMAL | - | 5 | 347.99 |
| 6 | [1.4970093965530396, 0.8270591497421265, 0.494... | [0.3286588191986084, 0.6616818904876709, 0.820... | [1.0638797283172607, 0.6806007027626038, 0.622... | [0.562666654586792, 0.7440000176429749, 0.7909... | [22.825047254562374, 44.74828147888184, 66.949... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.79 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[1574 149 312 64 114]\n [ 230 1586 166 ... | NORMAL | - | 5 | 467.94 |
| 7 | [1.224252700805664, 0.5908034443855286, 0.3403... | [0.4824941158294678, 0.7896470427513123, 0.886... | [0.7540565133094788, 0.5498793125152588, 0.583... | [0.7225098013877869, 0.8059607744216919, 0.817... | [23.619723081588745, 46.35664224624634, 68.851... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.80 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1559 158 332 46 118]\n [ 127 1619 214 ... | NORMAL | - | 5 | 344.16 |
| 8 | [2.038879871368408, 1.563637614250183, 1.32510... | [0.3195910453796386, 0.4616528749465942, 0.546... | [1.6314889192581177, 1.4994858503341677, 1.474... | [0.4371764659881592, 0.4931764602661133, 0.517... | [58.19883918762207, 116.40550565719604, 174.34... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | RAW | 200 | 10000 | True | [[1382 246 16 48 47 53 275 73 19... | NORMAL | - | 15 | 1216.27 |
| 9 | [2.007798671722412, 1.554110050201416, 1.32113... | [0.3344993591308594, 0.4640406668186188, 0.547... | [1.6179845333099363, 1.4908241033554075, 1.485... | [0.4401045739650726, 0.4930196106433868, 0.519... | [57.74453091621399, 115.19541382789612, 171.85... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | LOWER | 200 | 10000 | True | [[1182 218 46 45 52 60 311 104 10... | NORMAL | - | 15 | 1197.67 |
| 10 | [2.0945076942443848, 1.5407600402832031, 1.260... | [0.3058666586875915, 0.4837879538536072, 0.582... | [1.6695539951324463, 1.3899223804473877, 1.346... | [0.4403137266635895, 0.5398169755935669, 0.576... | [55.19337034225464, 109.7722339630127, 164.612... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1521 213 24 13 44 56 146 17 9... | NORMAL | - | 15 | 1143.94 |
| 11 | [2.0426511764526367, 1.4889464378356934, 1.221... | [0.331916332244873, 0.4952156841754913, 0.5902... | [1.6068023443222046, 1.3842854499816897, 1.342... | [0.4554771184921264, 0.5311895608901978, 0.570... | [57.66135025024414, 114.47401738166808, 171.94... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1535 153 14 8 31 47 290 53 11... | NORMAL | - | 15 | 1197.03 |
| 12 | [2.0767838954925537, 1.5745370388031006, 1.335... | [0.3215215802192688, 0.4560929536819458, 0.542... | [1.649091720581055, 1.498886227607727, 1.46961... | [0.4290196001529693, 0.4890457391738891, 0.523... | [68.15565466880798, 133.44615292549133, 200.59... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | RAW | 400 | 10000 | True | [[1258 179 48 47 35 38 323 37 10... | NORMAL | - | 15 | 1409.46 |
| 13 | [2.10368275642395, 1.6013576984405518, 1.34118... | [0.3073516488075256, 0.4520435631275177, 0.544... | [1.697721242904663, 1.4982041120529177, 1.4765... | [0.4145882427692413, 0.4970980286598205, 0.525... | [69.45309066772461, 136.9076645374298, 204.584... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.53 | 0 | 0 | LOWER | 400 | 10000 | True | [[1307 134 32 37 30 72 318 55 22... | NORMAL | - | 15 | 1433.87 |
| 14 | [2.148930549621582, 1.7112888097763062, 1.4490... | [0.2799529433250427, 0.3945388495922088, 0.491... | [1.811055302619934, 1.5882717370986938, 1.4911... | [0.3535163402557373, 0.4577777683734894, 0.499... | [64.82900762557983, 130.9426231384277, 195.940... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.55 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[1439 88 23 22 27 74 353 36 18... | NORMAL | - | 15 | 2350.03 |
| 15 | [2.1052839756011963, 1.5557817220687866, 1.270... | [0.3025882244110107, 0.4585562944412231, 0.561... | [1.6754242181777954, 1.404281497001648, 1.3419... | [0.414274513721466, 0.5210980176925659, 0.5575... | [68.50114560127258, 136.1558701992035, 204.361... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.56 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1434 184 20 21 17 45 317 107 14... | NORMAL | - | 15 | 1429.72 |
| 16 | [1.2060670852661133, 0.5876715183258057, 0.301... | [0.5213333368301392, 0.7960087060928345, 0.901... | [0.7380061745643616, 0.6602489352226257, 0.777... | [0.722196102142334, 0.7658039331436157, 0.7722... | [42.88776779174805, 84.30938196182251, 125.639... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | RAW | 200 | 10000 | True | [[1499 249 353 35 77]\n [ 109 1561 244 ... | NORMAL | - | 5 | 627.37 |
| 17 | [1.1218571662902832, 0.5309804677963257, 0.279... | [0.5644392371177673, 0.8198866844177246, 0.911... | [0.6812100410461426, 0.6584259867668152, 0.761... | [0.7573333382606506, 0.7799215912818909, 0.788... | [42.935046672821045, 84.6618320941925, 126.400... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | LOWER | 200 | 10000 | True | [[1526 250 228 57 152]\n [ 111 1533 188 ... | NORMAL | - | 5 | 630.26 |
| 18 | [1.2983185052871704, 0.67186439037323, 0.35850... | [0.4403764605522156, 0.7503268122673035, 0.875... | [0.8450406789779663, 0.6062559485435486, 0.653... | [0.6599215865135193, 0.7830588221549988, 0.803... | [41.16602444648743, 81.34677410125732, 120.331... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.79 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1461 291 227 76 158]\n [ 174 1726 94 ... | NORMAL | - | 5 | 604.90 |
| 19 | [1.1131271123886108, 0.4970025420188904, 0.263... | [0.5663529634475708, 0.8301350474357605, 0.912... | [0.634172260761261, 0.5626846551895142, 0.6592... | [0.7694117426872253, 0.8014117479324341, 0.809... | [42.32026171684265, 83.03080439567566, 123.852... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.80 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1508 290 287 32 96]\n [ 77 1770 190 ... | NORMAL | - | 5 | 618.58 |
| 20 | [1.4313589334487915, 0.976482629776001, 0.5856... | [0.3718431293964386, 0.5906579494476318, 0.785... | [1.1860140562057495, 0.7914318442344666, 0.690... | [0.492078423500061, 0.7022745013237, 0.7548235... | [75.83154845237732, 148.4723603725433, 221.875... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.76 | 0 | 0 | RAW | 400 | 10000 | True | [[1521 245 329 36 82]\n [ 127 1526 201 ... | NORMAL | - | 5 | 1535.43 |
| 21 | [1.2526357173919678, 0.7180959582328796, 0.415... | [0.4922509789466858, 0.7275816798210144, 0.859... | [0.8532054424285889, 0.7139995694160461, 0.696... | [0.6709019541740417, 0.7447842955589294, 0.779... | [72.2411801815033, 143.07429265975952, 214.991... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.77 | 0 | 0 | LOWER | 400 | 10000 | True | [[1541 148 435 16 73]\n [ 74 1643 317 ... | NORMAL | - | 5 | 1503.40 |
| 22 | [1.612979292869568, 1.6098767518997192, 1.6102... | [0.2574901878833771, 0.2004705816507339, 0.200... | [1.6093101501464844, 1.6092772483825684, 1.609... | [0.1923137307167053, 0.199215680360794, 0.1920... | [71.83844518661499, 143.0635724067688, 215.440... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.21 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[ 3 0 2210 0 0]\n [ 3 0 2200 ... | NORMAL | - | 5 | 1079.05 |
| 23 | [1.6123055219650269, 1.4136512279510498, 0.892... | [0.1997019648551941, 0.339607834815979, 0.6405... | [1.6083794832229614, 1.10409414768219, 0.74592... | [0.1921568661928177, 0.5328627228736877, 0.719... | [72.79997491836548, 144.0263090133667, 215.028... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.79 | 0 | 0 | LOWER_I | 400 | 10000 | True | [[1617 342 145 20 89]\n [ 121 1746 103 ... | NORMAL | - | 5 | 2015.48 |
| 24 | [2.010152578353882, 1.5290956497192385, 1.2335... | [0.3290140032768249, 0.4641103744506836, 0.562... | [1.6453301906585691, 1.5359132289886477, 1.550... | [0.4279738664627075, 0.4795294106006622, 0.506... | [122.05491280555724, 241.92626881599423, 364.0... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.49 | 0 | 0 | RAW | 200 | 10000 | True | [[1071 369 36 61 34 64 271 114 7... | NORMAL | - | 15 | 1822.20 |
| 25 | [2.0807976722717285, 1.5866154432296753, 1.302... | [0.3101803958415985, 0.4424575269222259, 0.535... | [1.694566249847412, 1.5601444244384766, 1.5679... | [0.4060130715370178, 0.4670849740505218, 0.481... | [122.93492102622986, 244.07795214653012, 366.3... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.47 | 0 | 0 | LOWER | 200 | 10000 | True | [[1157 287 34 93 38 72 240 96 18... | NORMAL | - | 15 | 1827.86 |
| 26 | [2.22964859008789, 1.658429503440857, 1.336077... | [0.257966011762619, 0.4204676747322082, 0.5402... | [1.8096245527267456, 1.5306107997894287, 1.440... | [0.3592156767845154, 0.4859085083007812, 0.534... | [124.48556709289552, 248.0258502960205, 368.81... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.54 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1180 105 23 97 50 109 413 26 7... | NORMAL | - | 15 | 2570.65 |
| 27 | [1.9838707447052, 1.4141194820404053, 1.121265... | [0.3484601378440857, 0.5100711584091187, 0.607... | [1.531771898269653, 1.3959068059921265, 1.3805... | [0.4722614288330078, 0.5345359444618225, 0.558... | [122.99041819572447, 245.43910098075867, 367.5... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1257 309 28 8 77 50 398 8 3... | NORMAL | - | 15 | 2569.64 |
| 28 | [2.298656940460205, 1.906553149223328, 1.59620... | [0.2510797381401062, 0.3431517779827118, 0.438... | [2.042023181915283, 1.7355191707611084, 1.6539... | [0.3018562197685241, 0.4036078453063965, 0.439... | [219.676766872406, 433.9881939888001, 650.6269... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.44 | 0 | 0 | RAW | 400 | 10000 | True | [[1050 246 42 204 46 94 260 77 4... | NORMAL | - | 15 | 4540.63 |
| 29 | [2.215775489807129, 1.8288313150405884, 1.5736... | [0.2619398832321167, 0.3532491028308868, 0.434... | [1.9354660511016848, 1.716828465461731, 1.6661... | [0.327320247888565, 0.3944052159786224, 0.4274... | [212.37938380241397, 417.8751857280731, 627.87... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.47 | 0 | 0 | LOWER | 400 | 10000 | True | [[1142 226 8 64 130 116 306 88 8... | NORMAL | - | 15 | 5882.86 |
| 30 | [2.7099225521087646, 2.708319902420044, 2.7084... | [0.1086745113134384, 0.0673812627792358, 0.066... | [2.7085378170013428, 2.708480834960937, 2.7085... | [0.0658823549747467, 0.0657777786254882, 0.064... | [213.5495536327362, 428.6686849594116, 646.272... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.43 | 0 | 0 | DEFAULT | 400 | 10000 | True | [[ 103 125 11 296 155 802 447 58 7... | NORMAL | - | 15 | 11962.07 |
px.bar(
dense_df,
x='Key',
y="Accuracy",
color='EmbeddingSize',
barmode="group",
facet_col="SeqLen",
facet_row="NumberOfAuthors",
text=dense_df.Accuracy,
title="Výsledky klasifikace autorů v závislosti na proměnných"
)
Jak lze pozorovat na výsledcích bylo experimentováno s velikostí vektoru, kterým bude reprezentováno slovo. Přičemž rozdíl mezi velikosti 50 a 300 nebyl markantní. Výsledky byly na všech typech předzpracovaných datech podobné. Můžeme s určitou jistotou říct, že pro tuto neuronovou síť stačí takhle velký číselný vektor, aby zachytil většinu relevantních vlastností slova. Velikost nad 50 nepřispěje k zlepšení přesnosti, akorát zvýší složitost neuronové sítě, neboli rozšíří počet parametrů, které bude potřeba naučit interně v neuronové síti, aby kvalitně bylo schopna předpovědět daného autora. U 15 autorů, pak lze pozorovat, že menší embedding hrál určitou roli v přesnosti.
Velikost vstupní sekvence jednou byla omezena na průměrnou velikost délky záznamu a v druhém případě na maximálně, takže žádná nebyla oříznuta. Spíše docházelo k operaci paddingu. U 15 autorů lze pozorovat jisté zlepšení u kratší sekvence.
Pro tento jednoduchý model lze pozorovat, že nejlepších výsledků dosahovalo předzpracování z knihovny gensim a slabší předzpracování, které transformovalo do malých písmen, odstranilo interpunkci, minimalizovalo počet bílých znaků. Důvodem může být, že tento model lépe pracoval právě s klíčovými slovy, která jsou maximálně relevantní. Šum bohužel zmenšoval přesnost modelu o jednotky procent.
Zároveň lze pozorovat, že 15 autorů pro síť bylo komplikované rozeznat. Přesnost kolem 60 % není nijak extra zajímavá. Narozdíl s 5 autory si síť dokázala poradit s přesností skoro 80 %.
rnn_df = df[df.ModelName == "Bidirectional GRU"]
rnn_df
| loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | CalculationTime | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 31 | [6687085056.0, 1187478372352.0, 1.320836067199... | [0.3819360733032226, 0.4748932421207428, 0.597... | [nan, 1.0665013790130615, 0.9180729389190674, ... | [0.294117659330368, 0.5433725714683533, 0.6208... | [1092.711680173874, 2239.074378967285, 3320.77... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.69 | 0 | 0 | RAW | 300 | 10000 | True | [[1428 125 512 34 114]\n [ 50 1417 393 ... | NORMAL | - | 5 | 62433.80 |
| 32 | [2406436.75, 25465.07421875, nan, nan, nan] | [0.3964235186576843, 0.5264662504196167, 0.383... | [1.1757404804229736, 1.1074206829071045, nan, ... | [0.485647052526474, 0.5080784559249878, 0.1921... | [1368.7052409648895, 2801.06485581398, 4307.10... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.51 | 0 | 0 | LOWER | 300 | 10000 | True | [[1731 434 4 3 41]\n [ 601 1472 94 ... | NORMAL | - | 5 | 21699.77 |
| 33 | [1.3197650909423828, 183.88555908203125, 0.674... | [0.3472784459590912, 0.5911459922790527, 0.748... | [1.1397396326065063, 0.7558140158653259, 0.713... | [0.4633725583553314, 0.7105882167816162, 0.735... | [1794.731845855713, 3350.3680169582367, 4541.5... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.74 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1087 167 730 68 161]\n [ 17 1742 242 ... | NORMAL | - | 5 | 30416.04 |
| 34 | [278055.28125, 1.4219011068344116, 31504.09765... | [0.3278745114803314, 0.4618736505508423, 0.516... | [1.2132927179336548, 1.139415979385376, 1.1058... | [0.4410980343818664, 0.4923921525478363, 0.526... | [1202.219447374344, 2394.1133301258087, 3581.0... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.75 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[1397 428 212 64 112]\n [ 143 1533 396 ... | NORMAL | - | 5 | 65487.81 |
| 35 | [nan, nan, nan] | [0.0903417393565177, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [3581.0230338573456, 7243.392804384232, 13244.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | RAW | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 | 24069.29 |
| 36 | [nan, nan, nan] | [0.0666928067803382, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [5272.877393007278, 10100.672457456589, 14649.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | LOWER | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 | 30023.22 |
| 37 | [143974.984375, 1.981084942817688, nan, nan, nan] | [0.1826091557741165, 0.2948496639728546, 0.301... | [2.123115539550781, 1.8301618099212649, nan, n... | [0.2459607869386673, 0.3584313690662384, 0.068... | [4745.052874326706, 9517.231180667875, 14253.2... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.36 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[ 552 57 29 485 1 132 234 68 4... | NORMAL | - | 15 | 70875.74 |
| 38 | [nan, nan, nan] | [0.0665464028716087, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [4647.039727687836, 9398.7957239151, 14184.882... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 50 | - | 0.07 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 | 28230.72 |
| 39 | [nan, nan, nan] | [0.1690849661827087, 0.2015163451433181, 0.201... | [nan, nan, nan] | [0.1921568661928177, 0.1921568661928177, 0.192... | [1732.056715965271, 3492.265405893326, 5230.93... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.20 | 0 | 0 | RAW | 300 | 10000 | True | [[2213 0 0 0 0]\n [2203 0 0 ... | NORMAL | - | 5 | 10455.26 |
| 40 | [nan, nan, nan] | [0.3445647060871124, 0.2015163451433181, 0.201... | [nan, nan, nan] | [0.1921568661928177, 0.1921568661928177, 0.192... | [1745.630146026611, 3480.498616695404, 4916.81... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.20 | 0 | 0 | LOWER | 300 | 10000 | True | [[2213 0 0 0 0]\n [2203 0 0 ... | NORMAL | - | 5 | 10142.94 |
| 41 | [9.50131130218506, 0.7876715660095215, 1.23750... | [0.3992627561092376, 0.6950762271881104, 0.808... | [0.9494749307632446, 0.6531278491020203, 0.620... | [0.6191372275352478, 0.7658039331436157, 0.788... | [1764.6829175949097, 3511.476901292801, 5241.2... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.80 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1597 259 167 86 104]\n [ 151 1747 64 ... | NORMAL | - | 5 | 48701.47 |
| 42 | [117312464.0, nan, nan, nan] | [0.4742588102817535, 0.4018823504447937, 0.201... | [1.0666706562042236, nan, nan, nan] | [0.5543529391288757, 0.1921568661928177, 0.192... | [1787.1415581703186, 3542.0427582263947, 5307.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.57 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[1077 131 929 3 73]\n [ 132 1246 712 ... | NORMAL | - | 5 | 17682.81 |
| 43 | [nan, nan, nan] | [0.1155686303973198, 0.0660798847675323, 0.066... | [nan, nan, nan] | [0.0684967339038848, 0.0684967339038848, 0.068... | [5241.320497512817, 10400.483037471771, 15546.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.07 | 0 | 0 | RAW | 300 | 10000 | True | [[2316 0 0 0 0 0 0 0 0... | NORMAL | - | 15 | 31188.12 |
| 44 | [4.139690399169922, 31051933696.0, 5946009.0, ... | [0.2126849740743637, 0.3071604967117309, 0.362... | [45.60334777832031, 1.8750523328781128, 1.7882... | [0.2989281117916107, 0.3549803793430328, 0.382... | [4035.43758225441, 7763.776393175125, 11458.88... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER | 300 | 10000 | True | [[1488 176 59 4 126 19 158 42 5... | NORMAL | - | 15 | 207936.63 |
| 45 | [3889620.75, 5.671187400817871, 174.6585388183... | [0.2653124034404754, 0.3765461146831512, 0.458... | [74537.96875, 1.6551882028579712, 3.5518321990... | [0.3310849666595459, 0.4476862847805023, 0.468... | [3747.2181475162506, 7466.645300388336, 11197.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.57 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1448 185 41 29 71 61 236 37 5... | NORMAL | - | 15 | 196815.26 |
| 46 | [47510450176.0, 10461987.0, 3.044379234313965,... | [0.2487633973360061, 0.3066957294940948, 0.376... | [2.28581562481608e+19, 1.265042553581863e+18, ... | [0.3026405274868011, 0.3697254955768585, 0.433... | [3202.14945268631, 6400.369203567505, 9596.970... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.43 | 0 | 0 | LOWER_I | 300 | 10000 | True | [[ 918 165 76 21 2 162 512 56 9... | NORMAL | - | 15 | 67171.45 |
px.bar(
rnn_df,
x='Key',
y="Accuracy",
color='EmbeddingSize',
barmode="group",
facet_col="SeqLen",
facet_row="NumberOfAuthors",
text=rnn_df.Accuracy,
title="Výsledky klasifikace autorů u RNN v závislosti na proměnných"
)
Jak lze pozorovat v mnoha případech síť nedokázala konvergovat k dobrým parametrům, aby mohla rozeznat autora. Můžeme se zde zaměřit pouze na výsledky s velikostí 50 a předzpracování pomocí gensim knihovny. Výsledná přesnost i přes takto složitější model nebyla poznatelně lepší. Bylo by potřeba v rámci projektu provést větší exploraci nad architekturou, aby bylo zjištěno, proč síť nemá tendenci správně konvergovat.
transformer_df = df[df.ModelName == "Transformer"]
transformer_df
| loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | CalculationTime | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 47 | [1.721323847770691, 1.6283942461013794, 1.6182... | [0.167633980512619, 0.200139433145523, 0.20036... | [1.6598328351974487, 1.6358952522277832, 1.622... | [0.1987451016902923, 0.1921568661928177, 0.192... | [12731.715883731842, 25516.930138111115, 38309... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[2213 0 0 0 0]\n [2203 0 0 ... | TL | distilbert-base-uncased | 5 | 76558.45 |
| 48 | [0.6788761615753174, 0.2942142188549042, 0.161... | [0.6950744986534119, 0.8956339955329895, 0.945... | [0.4689745604991913, 0.3302323520183563, 0.374... | [0.8304314017295837, 0.8854901790618896, 0.882... | [12792.493296146393, 25582.148589611053, 38391... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.88 | 0 | 0 | LOWER_I | 300 | - | True | [[1932 65 99 57 60]\n [ 96 1780 103 ... | TL | distilbert-base-uncased | 5 | 76765.80 |
| 49 | [1.720706582069397, 1.62993323802948, 1.619087... | [0.2695686221122741, 0.1993899792432785, 0.199... | [1.612503170967102, 1.6107159852981567, 1.6184... | [0.2062745094299316, 0.2043921500444412, 0.198... | [25539.273556947708, 51105.00569915772, 76684.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | bert-base-uncased | 5 | 153328.81 |
| 50 | [0.6129876375198364, 0.2475549280643463, 0.137... | [0.7192941308021545, 0.9139520525932312, 0.954... | [0.3863182961940765, 0.3083285987377167, 0.356... | [0.8680784106254578, 0.8953725695610046, 0.892... | [25577.77806091309, 51058.13135123253, 76546.3... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.89 | 0 | 0 | LOWER_I | 300 | - | True | [[1845 72 102 118 76]\n [ 34 1847 67 ... | TL | bert-base-uncased | 5 | 153182.25 |
| 51 | [1.7034351825714111, 1.6270469427108765, 1.617... | [0.2704313695430755, 0.198047935962677, 0.1987... | [1.6337881088256836, 1.6143453121185305, 1.611... | [0.2062745094299316, 0.2043921500444412, 0.198... | [8417.139698982239, 16806.370790719986, 25206.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | google/electra-small-discriminator | 5 | 50429.72 |
| 52 | [1.167790174484253, 0.668626070022583, 0.48968... | [0.5210353136062622, 0.760610044002533, 0.8271... | [0.8901610970497131, 0.8246893286705017, 0.576... | [0.6842352747917175, 0.7243921756744385, 0.799... | [8379.385590314865, 16772.64554834366, 25164.1... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.81 | 0 | 0 | LOWER_I | 300 | - | True | [[1545 117 274 131 146]\n [ 58 1393 243 ... | TL | google/electra-small-discriminator | 5 | 50316.22 |
px.bar(
transformer_df,
x='Key',
y="Accuracy",
color='TransformerName',
barmode="group",
facet_col="LR",
facet_row="NumberOfAuthors",
text=transformer_df.Accuracy,
title="Výsledky klasifikace autorů u Transformeru v závilosti na proměnných"
)
V experimentech bylo provedeno pouze pár běhů, jelikož transformery při tunění mají tendenci běžet delší dobu z důvodu široké interní reprezentace.
V první řadě byly provedeny běhy s učící konstantou 0.001, která je defaultně nastavena u Adam optimizeru. Vyšlo najevo, že takhle vysoká hodnota pro Transformery není vhod. Síť při tak velké učící konstantně nemá schopnost naučit se správnou interní reprezentaci. Další hodnota byla snížena na mnohem menší číslo, přesně 5e-05. Po této změně již všechny typy transformeru konvergovaly ke kvalitnímu výsledku.
Experimenty byly provedeny na 3 typech transformeru, a to:
Účelem bylo porovnání běhu těchto modelů a zároveň zhodnotit jejich přesnost predikce.
Jak lze tedy z výsledku pozorovat nejlépe dopadl nejkomplexnější a nejstarší model Bert, který překonal "jednoduché neuronové sítě" skoro o 10 procent u 5 autorů. Podobný posun lze možná předpokládat i u více autorů. Hnedka za Bert typem se dostala zjednodušená verze DistilBertu s přesností 88 %. Electra pak dosahovala slabších výsledků 81 %.
px.bar(
transformer_df,
x='Key',
y="CalculationTime",
color='TransformerName',
barmode="group",
facet_col="LR",
facet_row="NumberOfAuthors",
text=transformer_df.CalculationTime,
title="Výsledky klasifikace autorů u Transformeru v závilosti na proměnných"
)
Časová složitost je odstupňována, dle složitost modelu:
Po zhodnocení dostupných dat bychom se pravděpodobně vybrali model DistilBert, běžel vcelku rychle, dosahoval stejné přesnosti jako úplný model Bert. Důležité je myslet na nízkou učící konstantu.
transformer_df
| loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | CalculationTime | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 47 | [1.721323847770691, 1.6283942461013794, 1.6182... | [0.167633980512619, 0.200139433145523, 0.20036... | [1.6598328351974487, 1.6358952522277832, 1.622... | [0.1987451016902923, 0.1921568661928177, 0.192... | [12731.715883731842, 25516.930138111115, 38309... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[2213 0 0 0 0]\n [2203 0 0 ... | TL | distilbert-base-uncased | 5 | 76558.45 |
| 48 | [0.6788761615753174, 0.2942142188549042, 0.161... | [0.6950744986534119, 0.8956339955329895, 0.945... | [0.4689745604991913, 0.3302323520183563, 0.374... | [0.8304314017295837, 0.8854901790618896, 0.882... | [12792.493296146393, 25582.148589611053, 38391... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.88 | 0 | 0 | LOWER_I | 300 | - | True | [[1932 65 99 57 60]\n [ 96 1780 103 ... | TL | distilbert-base-uncased | 5 | 76765.80 |
| 49 | [1.720706582069397, 1.62993323802948, 1.619087... | [0.2695686221122741, 0.1993899792432785, 0.199... | [1.612503170967102, 1.6107159852981567, 1.6184... | [0.2062745094299316, 0.2043921500444412, 0.198... | [25539.273556947708, 51105.00569915772, 76684.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | bert-base-uncased | 5 | 153328.81 |
| 50 | [0.6129876375198364, 0.2475549280643463, 0.137... | [0.7192941308021545, 0.9139520525932312, 0.954... | [0.3863182961940765, 0.3083285987377167, 0.356... | [0.8680784106254578, 0.8953725695610046, 0.892... | [25577.77806091309, 51058.13135123253, 76546.3... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.89 | 0 | 0 | LOWER_I | 300 | - | True | [[1845 72 102 118 76]\n [ 34 1847 67 ... | TL | bert-base-uncased | 5 | 153182.25 |
| 51 | [1.7034351825714111, 1.6270469427108765, 1.617... | [0.2704313695430755, 0.198047935962677, 0.1987... | [1.6337881088256836, 1.6143453121185305, 1.611... | [0.2062745094299316, 0.2043921500444412, 0.198... | [8417.139698982239, 16806.370790719986, 25206.... | 0 | Transformer | 128 | Adam | 0.001 | - | - | - | 0.20 | 0 | 0 | LOWER_I | 300 | - | True | [[ 0 0 0 0 2213]\n [ 0 0 0 ... | TL | google/electra-small-discriminator | 5 | 50429.72 |
| 52 | [1.167790174484253, 0.668626070022583, 0.48968... | [0.5210353136062622, 0.760610044002533, 0.8271... | [0.8901610970497131, 0.8246893286705017, 0.576... | [0.6842352747917175, 0.7243921756744385, 0.799... | [8379.385590314865, 16772.64554834366, 25164.1... | 0 | Transformer | 128 | Adam | 5e-05 | - | - | - | 0.81 | 0 | 0 | LOWER_I | 300 | - | True | [[1545 117 274 131 146]\n [ 58 1393 243 ... | TL | google/electra-small-discriminator | 5 | 50316.22 |
def create_dfs(selector):
dfs = []
for i in range(len(transformer_df)):
row = transformer_df.iloc[i, :]
column = row[selector]
name = row.TransformerName + " " + row.LR
current_df = pd.DataFrame()
current_df['Type'] = len(column) * [name]
current_df['Value'] = column
dfs.append(current_df)
return dfs
df_val_loss = pd.concat(create_dfs('val_loss'))
df_val_acc = pd.concat(create_dfs('accuracy'))
px.line(df_val_loss, y='Value', color='Type', title="Chyba v transformerech")
px.line(df_val_acc, y='Value', color='Type', title="Přesnost v transformerech")
Na grafech lze vidět zmíněnou stagnaci modelu při vysoké učící konstantě a zároveň fakt, že Electra by možná mohla dosáhnout lepších výsledků, pakliže by mohlo dále běžet učení model. Výsledkem by mohlo být překonání Berta.
bert_confusion_df = transformer_df[transformer_df.TransformerName == "distilbert-base-uncased"]
bert_confusion_df = bert_confusion_df[bert_confusion_df.LR == "5e-05"]
bert_confusion = bert_confusion_df.ConfMatrix
def parse_confusion(conf):
conf = conf[1:len(conf)-1]
splited = conf.split('\n')
matrix = []
for row in splited:
row = row.strip()
row = row[1:len(row)-1]
row = strip_multiple_whitespaces_gensim(row).strip()
numbers = [int(n) for n in row.split(' ')]
matrix.append(numbers)
return np.array(matrix)
conf = parse_confusion(bert_confusion.values[0])
plt.figure(figsize=[20, 20])
sns.heatmap(conf, annot=True, fmt='g', cmap='Blues')
<AxesSubplot:>
Z vizualizované matice záměn lze pozorovat, že často dochází k záměně u 1 proti 2 a 3.
df_5 = df[df.NumberOfAuthors == "5"].reset_index()
df_5 = df_5.sort_values(by='Accuracy', ascending=False)
df_5['ConstructedModelName'] = df_5.ModelName + df_5.TransformerName
df_5 = df_5.iloc[0:5, :]
px.bar(
df_5,
x='ConstructedModelName',
y="Accuracy",
text=df_5.Accuracy,
title="Nejlepších 5 modelů u klasifikace 5 autorů",
color='ConstructedModelName'
)
Na grafu lze vidět, že jednoduchá hluboká neuronová síť dosáhla stejného výsledku jako Electra. Bert s 100m parametrů, pak zcela kraluje všem ostatním modelům.
px.bar(
df_5,
x='ConstructedModelName',
y="Accuracy",
text=df_5.Accuracy,
title="Nejlepších 5 modelů u klasifikace 5 autorů",
color='ConstructedModelName'
)
px.bar(
df_5,
x='ConstructedModelName',
y="CalculationTime",
text=df_5.CalculationTime,
title="Časová náročnost 5 modelů u klasifikace 5 autorů",
color='ConstructedModelName'
)
Je nutné pozorovat, že vyšší přesnost je za cenu mnohem vyšší časové složitosti. V tomto případě by šlo aplikovat paralelizaci na transformer modely, ale i přesto náročnost na výpočet lze pozorovat mnohem vyšší. Proto je nutné vždy zvážit, zda chceme benefitovat ze zvýšené přesnosti i na úkor tak vysoké složitosti.
df_15 = df[df.NumberOfAuthors == "15"].reset_index()
df_15 = df_15.sort_values(by='Accuracy', ascending=False)
df_15['ConstructedModelName'] = df_15.ModelName + df_15.TransformerName + [str(x) for x in df_15.index]
df_15 = df_15.iloc[0:5, :]
df_15
| index | loss | accuracy | val_loss | val_accuracy | time | NaN | ModelName | BatchSize | Optimizer | LR | Epochs | EmbeddingSize | Time | Accuracy | Hits | Miss | Key | SeqLen | VocabSize | TrainableEmbedding | ConfMatrix | Type | TransformerName | NumberOfAuthors | CalculationTime | ConstructedModelName | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | 10 | [2.0945076942443848, 1.5407600402832031, 1.260... | [0.3058666586875915, 0.4837879538536072, 0.582... | [1.6695539951324463, 1.3899223804473877, 1.346... | [0.4403137266635895, 0.5398169755935669, 0.576... | [55.19337034225464, 109.7722339630127, 164.612... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | DEFAULT | 200 | 10000 | True | [[1521 213 24 13 44 56 146 17 9... | NORMAL | - | 15 | 1143.94 | DENSE-2 |
| 3 | 11 | [2.0426511764526367, 1.4889464378356934, 1.221... | [0.331916332244873, 0.4952156841754913, 0.5902... | [1.6068023443222046, 1.3842854499816897, 1.342... | [0.4554771184921264, 0.5311895608901978, 0.570... | [57.66135025024414, 114.47401738166808, 171.94... | 0 | DENSE | 64 | Adam | 0.001 | - | 50 | - | 0.58 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1535 153 14 8 31 47 290 53 11... | NORMAL | - | 15 | 1197.03 | DENSE-3 |
| 21 | 45 | [3889620.75, 5.671187400817871, 174.6585388183... | [0.2653124034404754, 0.3765461146831512, 0.458... | [74537.96875, 1.6551882028579712, 3.5518321990... | [0.3310849666595459, 0.4476862847805023, 0.468... | [3747.2181475162506, 7466.645300388336, 11197.... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.57 | 0 | 0 | DEFAULT | 300 | 10000 | True | [[1448 185 41 29 71 61 236 37 5... | NORMAL | - | 15 | 196815.26 | Bidirectional GRU-21 |
| 11 | 27 | [1.9838707447052, 1.4141194820404053, 1.121265... | [0.3484601378440857, 0.5100711584091187, 0.607... | [1.531771898269653, 1.3959068059921265, 1.3805... | [0.4722614288330078, 0.5345359444618225, 0.558... | [122.99041819572447, 245.43910098075867, 367.5... | 0 | DENSE | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER_I | 200 | 10000 | True | [[1257 309 28 8 77 50 398 8 3... | NORMAL | - | 15 | 2569.64 | DENSE-11 |
| 20 | 44 | [4.139690399169922, 31051933696.0, 5946009.0, ... | [0.2126849740743637, 0.3071604967117309, 0.362... | [45.60334777832031, 1.8750523328781128, 1.7882... | [0.2989281117916107, 0.3549803793430328, 0.382... | [4035.43758225441, 7763.776393175125, 11458.88... | 0 | Bidirectional GRU | 64 | Adam | 0.001 | - | 300 | - | 0.56 | 0 | 0 | LOWER | 300 | 10000 | True | [[1488 176 59 4 126 19 158 42 5... | NORMAL | - | 15 | 207936.63 | Bidirectional GRU-20 |
px.bar(
df_15,
x='ConstructedModelName',
y="Accuracy",
text=df_15.Accuracy,
title="Nejlepších 5 modelů u klasifikace 15 autorů",
color='ConstructedModelName'
)
Rozdíl proti 5 autorům je dost poznatelný, jedná se o 30 %. Nutné je podotknout, že v těchto případech model má větší náchylnost na zmatení stylem ostatních autorů.
px.bar(
df_15,
x='ConstructedModelName',
y="CalculationTime",
text=df_15.CalculationTime,
title="Časová náročnost 5 modelů u klasifikace 15 autorů",
color='ConstructedModelName'
)
Časová náročnost je násobně vyšší u oboustranné RNN sítě. Důvodem je, že velikost vstupní sekvence je 200 a 300. Zpracování probíhá slovo po slovu, a proto trvání je takové.
Hluboká neuronová síť dosahuje celkem vysokých přesností, i když se jedná o tak jednoduchý model. Embedding vrstva odvádí skvělou práci, kdy je schopná se velice efektivně i přes jednoduchý model naučit číselnou reprezentaci pro každé slovo.
RNN nedosahovala dobrých výsledků, bylo by potřeba provést více experimentální činnosti.
Transformer model má vysokou časovou náročnost a je nutné specifikovat nízkou učící konstantu.
Transformer model dosáhl nejlepších výsledků u 5 autorů a to 89 procent.
Transformer model stačí malé množství epoch k tomu, aby byl schopný vyřešit náš definovaný problém.
Electra model umožňuje delší učení.